// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
)
// MockedHealthIndicator is a configurable test double for a health indicator
// (NOTE(review): presumably satisfies the health indicator interface — confirm).
// The Health result it produces is derived from these three fields.
type MockedHealthIndicator struct {
	Status      health.Status
	Description string
	Details     map[string]interface{}
}

// NewMockedHealthIndicator returns an indicator reporting health.StatusUp with
// description "mocked" and a single detail entry {"key": "value"}.
func NewMockedHealthIndicator() *MockedHealthIndicator {
	details := map[string]interface{}{"key": "value"}
	indicator := &MockedHealthIndicator{
		Status:      health.StatusUp,
		Description: "mocked",
		Details:     details,
	}
	return indicator
}

// Name returns the fixed indicator name "test".
func (i *MockedHealthIndicator) Name() string {
	const name = "test"
	return name
}
// Health builds a CompositeHealth from the indicator's Status and Description.
// When opts.ShowComponents is set it adds two components: "simple" and
// "detailed". The "detailed" component carries i.Details only when
// opts.ShowDetails is also set.
func (i *MockedHealthIndicator) Health(_ context.Context, opts health.Options) health.Health {
	result := health.CompositeHealth{
		SimpleHealth: health.SimpleHealth{Stat: i.Status, Desc: i.Description},
	}
	if !opts.ShowComponents {
		return result
	}

	// details stay nil unless explicitly requested
	var details map[string]interface{}
	if opts.ShowDetails {
		details = i.Details
	}
	result.Components = map[string]health.Health{
		"simple": health.SimpleHealth{Stat: i.Status, Desc: "mocked simple"},
		"detailed": health.DetailedHealth{
			SimpleHealth: health.SimpleHealth{Stat: i.Status, Desc: "mocked detailed"},
			Details:      details,
		},
	}
	return result
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package alive
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator"
"net/http"
)
// ID is the actuator endpoint identifier of the alive endpoint.
// EnableByDefault is passed as the endpoint's EnabledByDefault option.
const (
	ID              = "alive"
	EnableByDefault = true
)
// Input is the (empty) request payload accepted by the alive endpoint.
type Input struct{}

// Output is the alive endpoint's response body. The unexported sc field holds
// the HTTP status code and is not serialized; Message is emitted as "msg".
type Output struct {
	sc      int
	Message string `json:"msg"`
}

// StatusCode exposes the HTTP status code carried by Output.
// NOTE(review): the name matches a "status coder" style interface that the web
// layer presumably checks — confirm which interface.
func (o Output) StatusCode() int {
	code := o.sc
	return code
}
// AliveEndpoint implements actuator.Endpoint, actuator.WebEndpoint
// Its only operation is the read operation Read, which always reports HTTP 200.
type AliveEndpoint struct {
	actuator.WebEndpointBase
}
// new builds the alive endpoint with a single read operation bound to ep.Read.
// It uses the shared management endpoint properties from di.
// NOTE: this package-level func shadows the builtin new within package alive.
func new(di regDI) *AliveEndpoint {
	ep := &AliveEndpoint{}
	ep.WebEndpointBase = actuator.MakeWebEndpointBase(func(opt *actuator.EndpointOption) {
		opt.Id = ID
		opt.Ops = []actuator.Operation{actuator.NewReadOperation(ep.Read)}
		opt.Properties = &di.MgtProperties.Endpoints
		opt.EnabledByDefault = EnableByDefault
	})
	return ep
}
// Read reports liveness. It always answers HTTP 200 with the message
// "I'm good" and never returns an error; neither ctx nor input is consulted.
func (ep *AliveEndpoint) Read(ctx context.Context, input *Input) (Output, error) {
	var out Output
	out.sc = http.StatusOK
	out.Message = "I'm good"
	return out, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package alive
import (
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"go.uber.org/fx"
)
// Module wires the alive actuator endpoint into the application. It runs
// register through fx.Invoke, at actuator.MinActuatorPrecedence.
var Module = &bootstrap.Module{
	Name:       "actuator-alive",
	Precedence: actuator.MinActuatorPrecedence,
	Options:    []fx.Option{fx.Invoke(register)},
}

// Register adds Module to the bootstrap registry.
func Register() {
	m := Module
	bootstrap.Register(m)
}
// regDI collects the fx-injected dependencies needed to build and register
// the alive endpoint.
type regDI struct {
	fx.In
	Registrar     *actuator.Registrar
	MgtProperties actuator.ManagementProperties
}

// register builds the alive endpoint and hands it to the registrar via
// MustRegister (presumably panicking on failure, per the Must prefix).
func register(di regDI) {
	di.Registrar.MustRegister(new(di))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package apilist
import (
"context"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/web"
"io/fs"
"net/http"
)
// ID is the actuator endpoint identifier of the apilist endpoint.
// EnableByDefault is passed as the endpoint's EnabledByDefault option.
const (
	ID              = "apilist"
	EnableByDefault = false
)

// ApiListEndpoint implements actuator.Endpoint, actuator.WebEndpoint
// Its read operation serves the JSON content of the file at staticPath.
//goland:noinspection GoNameStartsWithPackageName
type ApiListEndpoint struct {
	actuator.WebEndpointBase
	// staticPath is an io/fs style (slash-separated, relative) path to the JSON file;
	// newEndpoint rejects values that fail fs.ValidPath
	staticPath string
}
// newEndpoint validates the configured static path and builds the apilist
// endpoint with a single read operation bound to ep.Read.
// It panics when the path is not a valid io/fs path (see fs.ValidPath); for
// example, a path that is absolute or contains "..".
func newEndpoint(di regDI) *ApiListEndpoint {
	staticPath := di.Properties.StaticPath
	if valid := fs.ValidPath(staticPath); !valid {
		panic("invalid static-path for apilist endpoint")
	}
	ep := &ApiListEndpoint{staticPath: staticPath}
	ep.WebEndpointBase = actuator.MakeWebEndpointBase(func(opt *actuator.EndpointOption) {
		opt.Id = ID
		opt.Ops = []actuator.Operation{actuator.NewReadOperation(ep.Read)}
		opt.Properties = &di.MgtProperties.Endpoints
		opt.EnabledByDefault = EnableByDefault
	})
	return ep
}
// Read returns the decoded JSON content of the static API list file.
// If the file cannot be loaded or parsed, the cause is logged at warn level and
// the client receives a 404 HTTP error, so internal details are not exposed.
func (ep *ApiListEndpoint) Read(ctx context.Context, _ *struct{}) (interface{}, error) {
	resp, err := parseFromStaticFile(ep.staticPath)
	if err == nil {
		return resp, nil
	}
	logger.WithContext(ctx).Warnf(`unable to load static API list file "%s": %v`, ep.staticPath, err)
	return nil, web.NewHttpError(http.StatusNotFound, fmt.Errorf("APIList is not available"))
}
// parseFromStaticFile opens path from the first filesystem in staticFS that
// can open it, then decodes the file as JSON into a generic value.
//
// If no filesystem can open the file, the error from the last attempt is
// returned. If staticFS is empty, a *fs.PathError wrapping fs.ErrNotExist is
// returned; previously this case dereferenced a nil file and panicked.
func parseFromStaticFile(path string) (ret interface{}, err error) {
	// seed with "not found" so an empty staticFS still yields a proper error
	var file fs.File
	err = &fs.PathError{Op: "open", Path: path, Err: fs.ErrNotExist}
	for _, fsys := range staticFS {
		if file, err = fsys.Open(path); err == nil {
			break
		}
	}
	if err != nil {
		return nil, err
	}
	// read-only file: a Close error carries no useful information
	defer func() { _ = file.Close() }()

	if err = json.NewDecoder(file).Decode(&ret); err != nil {
		return nil, err
	}
	return ret, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package apilist
import (
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
"io/fs"
"os"
)
// logger is the package logger for the apilist actuator endpoint.
var logger = log.New("ACTR.APIList")
// staticFS lists the filesystems searched, in order, for the static API list file.
// By default it is the process working directory (os.DirFS(".")); StaticFS replaces it.
var staticFS = []fs.FS{os.DirFS(".")}
// Module wires the apilist actuator endpoint into the application. It provides
// Properties via BindProperties and runs register through fx.Invoke, at
// actuator.MinActuatorPrecedence.
var Module = &bootstrap.Module{
	Name:       "actuator-apilist",
	Precedence: actuator.MinActuatorPrecedence,
	Options: []fx.Option{
		fx.Provide(BindProperties),
		fx.Invoke(register),
	},
}

// Register adds Module to the bootstrap registry.
func Register() {
	m := Module
	bootstrap.Register(m)
}
// StaticFS replaces the filesystems searched for the static API list file.
// Filesystems are tried in the given order. Calling StaticFS with no arguments
// keeps the current configuration.
//
// The arguments are copied, so later changes to the caller's slice do not
// affect the endpoint. The parameter is named fsys to avoid shadowing the
// io/fs package.
func StaticFS(fsys ...fs.FS) {
	if len(fsys) == 0 {
		return
	}
	staticFS = append([]fs.FS(nil), fsys...)
}
// regDI collects the fx-injected dependencies for the apilist endpoint.
type regDI struct {
	fx.In
	Registrar     *actuator.Registrar
	MgtProperties actuator.ManagementProperties
	Properties    Properties
}

// register builds the apilist endpoint from di and hands it to the registrar.
func register(di regDI) {
	di.Registrar.MustRegister(newEndpoint(di))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package apilist
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/pkg/errors"
)
// PropertiesPrefix is the configuration prefix bound into Properties.
const (
	PropertiesPrefix = "management.endpoint.apilist"
)

// FSType is declared but not referenced in this file.
// NOTE(review): appears unused — confirm before removing.
type FSType int

// Properties configures the apilist endpoint.
type Properties struct {
	// StaticPath is the slash-separated path of the JSON file to serve; it must pass fs.ValidPath
	StaticPath string `json:"static-path"`
}
// NewProperties creates a Properties populated with defaults; StaticPath
// defaults to "configs/api-list.json".
func NewProperties() *Properties {
	props := new(Properties)
	props.StaticPath = "configs/api-list.json"
	return props
}

// BindProperties creates a Properties with default values and overlays the
// configuration found under PropertiesPrefix. It panics if binding fails.
func BindProperties(ctx *bootstrap.ApplicationContext) Properties {
	props := NewProperties()
	err := ctx.Config().Bind(props, PropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind Properties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuator
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/rest"
"net/http"
"reflect"
"strings"
)
// Cached interface types used to validate OperationFunc signatures at runtime:
// the first parameter must implement context.Context and the last result error.
var (
	ctxType   = reflect.TypeOf(new(context.Context)).Elem()
	errorType = reflect.TypeOf(new(error)).Elem()
)
/*******************************
Operation
********************************/
// operation implements Operation and holds metadata about the OperationFunc,
// obtained via reflection.
//   - mode: read or write
//   - f: the OperationFunc as supplied by the caller
//   - matcher: optional extra condition on the input; nil accepts any convertible input
//   - function: reflect.Value of f, invoked by Execute
//   - input: type of f's last parameter
//   - output: type of f's first result when f has two results; nil otherwise
type operation struct {
	mode     OperationMode
	f        OperationFunc
	matcher  matcher.Matcher
	function reflect.Value
	input    reflect.Type
	output   reflect.Type
}
// newOperation builds an operation for opFunc in the given mode.
// Multiple inputMatchers are OR-combined. With none, the operation accepts any
// input convertible to opFunc's input type. It panics if opFunc does not have
// a supported signature.
func newOperation(mode OperationMode, opFunc OperationFunc, inputMatchers ...matcher.Matcher) *operation {
	op := &operation{
		mode: mode,
		f:    opFunc,
	}
	switch n := len(inputMatchers); {
	case n == 1:
		op.matcher = inputMatchers[0]
	case n > 1:
		op.matcher = matcher.Or(inputMatchers[0], inputMatchers[1:]...)
	}
	if err := populateOperationMetadata(opFunc, op); err != nil {
		panic(err)
	}
	return op
}
// NewReadOperation creates a read-mode Operation backed by opFunc. Any
// inputMatchers are OR-combined. It panics if opFunc's signature is invalid.
func NewReadOperation(opFunc OperationFunc, inputMatchers ...matcher.Matcher) Operation {
	op := newOperation(OperationRead, opFunc, inputMatchers...)
	return op
}

// NewWriteOperation creates a write-mode Operation backed by opFunc. Any
// inputMatchers are OR-combined. It panics if opFunc's signature is invalid.
func NewWriteOperation(opFunc OperationFunc, inputMatchers ...matcher.Matcher) Operation {
	op := newOperation(OperationWrite, opFunc, inputMatchers...)
	return op
}
// Mode reports whether the operation is a read or a write.
func (op operation) Mode() OperationMode { return op.mode }

// Func returns the OperationFunc exactly as it was supplied.
func (op operation) Func() OperationFunc { return op.f }
// Matches reports whether this operation should handle input in the given mode.
// All of the following must hold:
//   - mode equals the operation's mode
//   - input is non-nil and its type is convertible to the OperationFunc's input type
//   - if an input matcher is configured, it matches input without error
//
// Previously the matcher result was inverted: a match was reported exactly
// when the matcher failed or errored.
func (op operation) Matches(ctx context.Context, mode OperationMode, input interface{}) bool {
	// a nil input has no reflect.Type; ConvertibleTo on it would panic
	if mode != op.mode || input == nil {
		return false
	}
	if !reflect.TypeOf(input).ConvertibleTo(op.input) {
		return false
	}
	if op.matcher == nil {
		return true
	}
	matched, err := op.matcher.MatchesWithContext(ctx, input)
	// a matcher error is treated as "no match"
	return err == nil && matched
}
// Execute invokes the OperationFunc with ctx and input, after converting input
// to the function's input type, and unpacks the results:
//   - one result: it is the error
//   - two results: the first is the output, the second the error
//
// A nil error result becomes a nil error. The previous unchecked assertion
// `.(error)` on a nil interface panicked on every successful call.
func (op operation) Execute(ctx context.Context, input interface{}) (interface{}, error) {
	in := reflect.ValueOf(input).Convert(op.input)
	ret := op.function.Call([]reflect.Value{reflect.ValueOf(ctx), in})
	switch len(ret) {
	case 1:
		err, _ := ret[0].Interface().(error)
		return nil, err
	case 2:
		err, _ := ret[1].Interface().(error)
		return ret[0].Interface(), err
	default:
		// not reachable for signatures accepted by validateFunc (1 or 2 results); kept as a safety net
		for _, v := range ret {
			if err, ok := v.Interface().(error); ok {
				return nil, err
			}
		}
		return nil, fmt.Errorf("operation failed with unknown error")
	}
}
// populateOperationMetadata validates opFunc's signature and records three
// things on op: the function's reflect.Value, its input type (the last
// parameter) and, when it returns more than one value, its output type (the
// first result). op.function is set even when validation fails.
func populateOperationMetadata(opFunc OperationFunc, op *operation) error {
	fn := reflect.ValueOf(opFunc)
	op.function = fn
	if err := validateFunc(fn); err != nil {
		return err
	}
	fnType := fn.Type()
	op.input = fnType.In(fnType.NumIn() - 1)
	if fnType.NumOut() > 1 {
		op.output = fnType.Out(0)
	}
	return nil
}
// validateFunc checks at runtime that f has the OperationFunc shape.
//
// Parameters: exactly two, a context.Context followed by a struct or pointer to
// struct.
//
// Results: one or two, the last implementing error. When there are two, the
// first must be a struct, pointer to struct, map or interface.
//
// A zero reflect.Value (from a nil opFunc) is rejected with an error. The
// previous code called f.Interface() on it, which panics inside reflect.
func validateFunc(f reflect.Value) error {
	if !f.IsValid() {
		return fmt.Errorf("OperationFunc must be a function but got %T", nil)
	}
	if f.Kind() != reflect.Func {
		return fmt.Errorf("OperationFunc must be a function but got %T", f.Interface())
	}
	t := f.Type()
	// && short-circuits, so t.In / t.Out are only indexed once the counts are known to be valid
	validIn := t.NumIn() == 2 &&
		t.In(0).Implements(ctxType) &&
		isStructOrPtrToStruct(t.In(1))
	validOut := t.NumOut() >= 1 && t.NumOut() <= 2 &&
		t.Out(t.NumOut()-1).Implements(errorType) &&
		(t.NumOut() == 1 || isSupportedOutputType(t.Out(0)))
	if !validIn || !validOut {
		return invalidOpFuncSignatureError(f.Interface())
	}
	return nil
}
// isStructOrPtrToStruct reports whether t is a struct type or a pointer to a struct type.
func isStructOrPtrToStruct(t reflect.Type) (ret bool) {
	switch t.Kind() {
	case reflect.Struct:
		return true
	case reflect.Ptr:
		return t.Elem().Kind() == reflect.Struct
	default:
		return false
	}
}

// isSupportedOutputType reports whether t may be the first result of an
// OperationFunc: a struct, a pointer to a struct, an interface or a map.
func isSupportedOutputType(t reflect.Type) bool {
	kind := t.Kind()
	return isStructOrPtrToStruct(t) || kind == reflect.Interface || kind == reflect.Map
}

// invalidOpFuncSignatureError builds the error reported for an OperationFunc
// with an unsupported signature; f is formatted with %v.
func invalidOpFuncSignatureError(f interface{}) error {
	const format = "OperationFunc must have signigure 'func(ctx context.Context, input Type1) (Type2, error)', but got %v"
	return fmt.Errorf(format, f)
}
/*******************************
EndpointBase
********************************/
// EndpointBase implements EndpointExecutor and partially Endpoint, and can be
// embedded into any Endpoint implementation. Read/write requests are
// dispatched to the first Operation whose Matches accepts them and are run via
// Operation.Execute.
type EndpointBase struct {
	// id is returned by Id()
	id string
	// operations are tried in order by ReadOperation / WriteOperation
	operations []Operation
	// properties is the shared endpoint configuration (its Web section is used by WebEndpointBase)
	properties *EndpointsProperties
	// enabledByDefault is returned by EnabledByDefault()
	enabledByDefault bool
}

// EndpointOptions is a functional option that mutates an EndpointOption.
type EndpointOptions func(opt *EndpointOption)

// EndpointOption carries the settings consumed by MakeEndpointBase and MakeWebEndpointBase.
type EndpointOption struct {
	Id               string
	Ops              []Operation
	Properties       *EndpointsProperties
	EnabledByDefault bool
}
// MakeEndpointBase applies each option function, in order, to a zero
// EndpointOption and builds an EndpointBase from the result.
func MakeEndpointBase(opts ...EndpointOptions) EndpointBase {
	var opt EndpointOption
	for i := range opts {
		opts[i](&opt)
	}
	var base EndpointBase
	base.id = opt.Id
	base.operations = opt.Ops
	base.properties = opt.Properties
	base.enabledByDefault = opt.EnabledByDefault
	return base
}
// Id returns the endpoint identifier.
func (b EndpointBase) Id() string { return b.id }

// EnabledByDefault reports whether the endpoint is enabled when not explicitly configured.
func (b EndpointBase) EnabledByDefault() bool { return b.enabledByDefault }

// Operations returns the endpoint's operations in registration order.
func (b EndpointBase) Operations() []Operation { return b.operations }
// ReadOperation executes the first read Operation that matches input.
// It returns an error if no read operation matches.
func (b EndpointBase) ReadOperation(ctx context.Context, input interface{}) (interface{}, error) {
	if op := b.firstMatchingOperation(ctx, OperationRead, input); op != nil {
		return op.Execute(ctx, input)
	}
	return nil, fmt.Errorf("unsupported read operation [%s] with input [%v]", b.Id(), input)
}

// WriteOperation executes the first write Operation that matches input.
// It returns an error if no write operation matches.
func (b EndpointBase) WriteOperation(ctx context.Context, input interface{}) (interface{}, error) {
	if op := b.firstMatchingOperation(ctx, OperationWrite, input); op != nil {
		return op.Execute(ctx, input)
	}
	return nil, fmt.Errorf("unsupported write operation [%s] with input [%v]", b.Id(), input)
}

// firstMatchingOperation returns the first registered Operation matching mode and input, or nil.
func (b EndpointBase) firstMatchingOperation(ctx context.Context, mode OperationMode, input interface{}) Operation {
	for _, op := range b.operations {
		if op.Matches(ctx, mode, input) {
			return op
		}
	}
	return nil
}
/*******************************
WebEndpointBase
********************************/
// MappingPathFunc computes the URL path of op from the web endpoint properties.
type MappingPathFunc func(op Operation, props *WebEndpointsProperties) string

// MappingNameFunc computes the mapping name of op.
type MappingNameFunc func(op Operation) string

// WebEndpointBase is similar to EndpointBase and implements default WebEndpoint mapping
type WebEndpointBase struct {
	EndpointBase
	// properties is the Web section of the endpoint properties, used for path mapping.
	// NOTE: it shadows EndpointBase.properties when accessed as b.properties.
	properties *WebEndpointsProperties
	// formats maps negotiated content types to response encoders
	formats map[string]web.EncodeResponseFunc
}
// MakeWebEndpointBase builds an EndpointBase from opts and adds web mapping
// support. Web properties come from the Web section of the configured
// EndpointsProperties, so opts must set a non-nil Properties. Response
// encoders are registered for the Spring Boot v3 and v2 actuator content
// types and for "application/json".
func MakeWebEndpointBase(opts ...EndpointOptions) WebEndpointBase {
	base := MakeEndpointBase(opts...)
	webProps := &base.properties.Web
	encoders := map[string]web.EncodeResponseFunc{
		ContentTypeSpringBootV3: SpringBootRespEncoderV3(),
		ContentTypeSpringBootV2: SpringBootRespEncoderV2(),
		"application/json":      web.JsonResponseEncoder(),
	}
	return WebEndpointBase{EndpointBase: base, properties: webProps, formats: encoders}
}
// Mappings implements WebEndpoint. It produces a single REST mapping for op,
// using b.MappingPath and b.MappingName. Builder errors, such as a non-empty
// group, are returned unchanged.
func (b WebEndpointBase) Mappings(op Operation, group string) ([]web.Mapping, error) {
	builder, err := b.RestMappingBuilder(op, group, b.MappingPath, b.MappingName)
	if err != nil {
		return nil, err
	}
	mapping := builder.Build()
	return []web.Mapping{mapping}, nil
}
// MappingPath returns the endpoint's URL path as "/<base>/<path>".
//   - base is props.BasePath with surrounding slashes trimmed
//   - path is props.Mappings[id], falling back to the lower-cased endpoint id,
//     also with surrounding slashes trimmed
//
// When the base path is empty or "/", the result is "/<path>"; previously this
// produced "//<path>".
func (b WebEndpointBase) MappingPath(_ Operation, props *WebEndpointsProperties) string {
	base := strings.Trim(props.BasePath, "/")
	path, ok := props.Mappings[b.id]
	if !ok {
		path = strings.ToLower(b.id)
	}
	path = strings.Trim(path, "/")
	if base == "" {
		return "/" + path
	}
	return fmt.Sprintf("/%s/%s", base, path)
}
// MappingName returns "<lower-cased id> GET" for read operations and
// "<lower-cased id> POST" for write operations. It returns "" for any other mode.
func (b WebEndpointBase) MappingName(op Operation) string {
	var method string
	switch op.Mode() {
	case OperationRead:
		method = "GET"
	case OperationWrite:
		method = "POST"
	default:
		return ""
	}
	return strings.ToLower(b.id) + " " + method
}
// RestMappingBuilder prepares a REST mapping builder for op. The path comes
// from pathFunc, the name from nameFunc, the endpoint function is op.Func(),
// and responses use the negotiable encoder. Read operations map to GET and
// write operations to POST.
// It fails when group is non-empty or op has an unsupported mode.
func (b WebEndpointBase) RestMappingBuilder(op Operation, group string,
	pathFunc MappingPathFunc, nameFunc MappingNameFunc) (*rest.MappingBuilder, error) {
	// NOTE: our current web implementation don't support different context-path (group)
	if len(group) != 0 {
		return nil, fmt.Errorf("adding actuator endpoints to different context-path/group is not supported at the moment")
	}
	path := pathFunc(op, b.properties)
	name := nameFunc(op)
	builder := rest.New(name).Path(path).EndpointFunc(op.Func()).EncodeResponseFunc(b.NegotiableResponseEncoder())

	var method string
	switch op.Mode() {
	case OperationRead:
		method = http.MethodGet
	case OperationWrite:
		method = http.MethodPost
	default:
		return nil, fmt.Errorf("unsupported operation mode")
	}
	return builder.Method(method), nil
}
// NegotiateFormat asks the request's gin context to choose between the Spring
// Boot v3 and v2 actuator content types. It falls back to v3 when there is no
// gin context or nothing was negotiated.
func (b WebEndpointBase) NegotiateFormat(ctx context.Context) string {
	gc := web.GinContext(ctx)
	if gc == nil {
		return ContentTypeSpringBootV3
	}
	if f := gc.NegotiateFormat(ContentTypeSpringBootV3, ContentTypeSpringBootV2); f != "" {
		return f
	}
	return ContentTypeSpringBootV3
}
// NegotiableResponseEncoder returns an encoder that picks, per request, the
// encoder registered for the negotiated format. Unknown formats fall back to
// the plain JSON encoder.
func (b WebEndpointBase) NegotiableResponseEncoder() web.EncodeResponseFunc {
	return func(ctx context.Context, rw http.ResponseWriter, i interface{}) error {
		enc, ok := b.formats[b.NegotiateFormat(ctx)]
		if !ok {
			enc = web.JsonResponseEncoder()
		}
		return enc(ctx, rw, i)
	}
}
// SpringBootRespEncoderV3 returns a JSON response encoder that sets the Spring
// Boot v3 actuator content type.
func SpringBootRespEncoderV3() web.EncodeResponseFunc {
	configure := func(opt *web.EncodeOption) {
		opt.ContentType = ContentTypeSpringBootV3
		opt.WriteFunc = web.JsonWriteFunc
	}
	return web.CustomResponseEncoder(configure)
}

// SpringBootRespEncoderV2 returns a JSON response encoder that sets the Spring
// Boot v2 actuator content type.
func SpringBootRespEncoderV2() web.EncodeResponseFunc {
	configure := func(opt *web.EncodeOption) {
		opt.ContentType = ContentTypeSpringBootV2
		opt.WriteFunc = web.JsonWriteFunc
	}
	return web.CustomResponseEncoder(configure)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuator
import (
"context"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/web"
)
// Operation modes. WebEndpointBase maps OperationRead to HTTP GET and
// OperationWrite to HTTP POST.
const (
	OperationRead OperationMode = iota
	OperationWrite
)

// OperationMode distinguishes read operations from write operations.
type OperationMode int

// OperationFunc is a func that have following signature:
// func(ctx context.Context, input StructsOrPointerType1) (StructsOrPointerType2, error)
// where
// - StructsOrPointerType1 can be any struct or struct pointer
// - StructsOrPointerType2 can be any struct or struct pointer, or a map or interface type
// - input might be ignored by particular Endpoint impl.
// - 1st output is optional (intended for "write" operations; the signature check does not enforce the mode)
//
// Note: the signature is checked at runtime via reflection, since this type is an empty interface.
type OperationFunc interface{}
// Operation is a single read or write action of an Endpoint.
type Operation interface {
	// Mode reports whether this is a read or write operation
	Mode() OperationMode
	// Func returns the underlying OperationFunc
	Func() OperationFunc
	// Matches reports whether this operation should handle input in the given mode
	Matches(ctx context.Context, mode OperationMode, input interface{}) bool
	// Execute runs the operation with input
	Execute(ctx context.Context, input interface{}) (interface{}, error)
}

// Endpoint is an actuator endpoint identified by Id and composed of Operations.
type Endpoint interface {
	Id() string
	EnabledByDefault() bool
	Operations() []Operation
}

// WebEndpoint exposes an Endpoint's operations as web mappings.
type WebEndpoint interface {
	Mappings(op Operation, group string) ([]web.Mapping, error)
}

// EndpointExecutor dispatches read and write requests to an endpoint's operations.
type EndpointExecutor interface {
	ReadOperation(ctx context.Context, input interface{}) (interface{}, error)
	WriteOperation(ctx context.Context, input interface{}) (interface{}, error)
}

// WebEndpoints maps endpoint IDs to their web mappings.
type WebEndpoints map[string][]web.Mapping
// EndpointIDs returns the IDs of all endpoints in w. The order is unspecified
// (Go map iteration order).
func (w WebEndpoints) EndpointIDs() (ret []string) {
	ret = make([]string, 0, len(w))
	for id := range w {
		ret = append(ret, id)
	}
	return ret
}
// Paths returns the distinct path patterns of the endpoint with the given ID.
// Only web.RoutedMapping and web.StaticMapping expose a path; other mapping
// kinds are skipped. An unknown ID yields an empty, non-nil slice.
func (w WebEndpoints) Paths(id string) []string {
	mappings, ok := w[id]
	if !ok {
		return []string{}
	}
	paths := utils.NewStringSet()
	for _, mapping := range mappings {
		switch m := mapping.(type) {
		case web.RoutedMapping:
			paths.Add(m.Path())
		case web.StaticMapping:
			paths.Add(m.Path())
		}
	}
	return paths.Values()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package env
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/appconfig"
"sort"
)
// ID is the actuator endpoint identifier of the env endpoint.
// EnableByDefault is passed as the endpoint's EnabledByDefault option.
const (
	ID              = "env"
	EnableByDefault = false
)

// Input is the env endpoint's request input. Pattern is bound from the "match"
// form field; Read does not use it yet.
type Input struct {
	Pattern string `form:"match"`
}

// EnvDescriptor is the env endpoint response: the active profiles and one
// entry per non-empty property source.
type EnvDescriptor struct {
	ActiveProfiles  []string            `json:"activeProfiles,omitempty"`
	PropertySources []PSourceDescriptor `json:"propertySources,omitempty"`
}

// PSourceDescriptor describes one configuration provider and its (sanitized) properties.
type PSourceDescriptor struct {
	Name       string                      `json:"name"`
	Properties map[string]PValueDescriptor `json:"properties,omitempty"`
	// order is the provider's Order(), used only for sorting (unexported, not serialized)
	order int
}

// PValueDescriptor is a single property value. Origin is currently always left empty by Read.
type PValueDescriptor struct {
	Value  interface{} `json:"value,omitempty"`
	Origin string      `json:"origin,omitempty"`
}

// EnvEndpoint implements actuator.Endpoint, actuator.WebEndpoint
type EnvEndpoint struct {
	actuator.WebEndpointBase
	appConfig appconfig.ConfigAccessor
	sanitizer *Sanitizer
}
// new builds the env endpoint from di. It takes the application config as an
// appconfig.ConfigAccessor (the unchecked type assertion panics otherwise) and
// a sanitizer built from the configured keys-to-sanitize. It registers a
// single read operation bound to ep.Read.
// NOTE: this package-level func shadows the builtin new within package env.
func new(di regDI) *EnvEndpoint {
	accessor := di.AppContext.Config().(appconfig.ConfigAccessor)
	sanitizer := NewSanitizer(di.Properties.KeysToSanitize.Values())
	ep := &EnvEndpoint{appConfig: accessor, sanitizer: sanitizer}
	ep.WebEndpointBase = actuator.MakeWebEndpointBase(func(opt *actuator.EndpointOption) {
		opt.Id = ID
		opt.Ops = []actuator.Operation{actuator.NewReadOperation(ep.Read)}
		opt.Properties = &di.MgtProperties.Endpoints
		opt.EnabledByDefault = EnableByDefault
	})
	return ep
}
// Read describes the current environment: the active profiles plus one
// property source per loaded config provider that has at least one setting.
// Every value goes through the sanitizer. Sources are sorted by provider
// Order(), ascending and stable. input is currently ignored. Read never
// returns an error.
func (ep *EnvEndpoint) Read(ctx context.Context, input *Input) (*EnvDescriptor, error) {
	// TODO maybe support match pattern
	env := &EnvDescriptor{
		ActiveProfiles:  ep.appConfig.Profiles(),
		PropertySources: []PSourceDescriptor{},
	}
	for _, provider := range ep.appConfig.Providers() {
		if !provider.IsLoaded() {
			continue
		}
		src := PSourceDescriptor{
			Name:       provider.Name(),
			Properties: map[string]PValueDescriptor{},
			order:      provider.Order(),
		}
		settings := provider.GetSettings()
		// the visitor never fails; any error from VisitEach itself is deliberately ignored
		_ = appconfig.VisitEach(settings, func(k string, v interface{}) error {
			src.Properties[k] = PValueDescriptor{Value: ep.sanitizer.Sanitize(ctx, k, v)}
			return nil
		})
		if len(src.Properties) != 0 {
			env.PropertySources = append(env.PropertySources, src)
		}
	}
	sort.SliceStable(env.PropertySources, func(i, j int) bool {
		return env.PropertySources[i].order < env.PropertySources[j].order
	})
	return env, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package env
import (
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"go.uber.org/fx"
)
// Module wires the env actuator endpoint into the application. It provides
// EnvProperties via BindEnvProperties and runs register through fx.Invoke, at
// actuator.MinActuatorPrecedence.
var Module = &bootstrap.Module{
	Name:       "actuator-env",
	Precedence: actuator.MinActuatorPrecedence,
	Options: []fx.Option{
		fx.Provide(BindEnvProperties),
		fx.Invoke(register),
	},
}

// Register adds Module to the bootstrap registry.
func Register() {
	m := Module
	bootstrap.Register(m)
}
// regDI collects the fx-injected dependencies for the env endpoint.
type regDI struct {
	fx.In
	Registrar     *actuator.Registrar
	MgtProperties actuator.ManagementProperties
	AppContext    *bootstrap.ApplicationContext
	Properties    EnvProperties
}

// register builds the env endpoint from di and hands it to the registrar.
func register(di regDI) {
	di.Registrar.MustRegister(new(di))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package env
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
)
const (
	// EnvPropertiesPrefix is the configuration prefix bound into EnvProperties.
	EnvPropertiesPrefix = "management.endpoint.env"
)

// EnvProperties configures the env endpoint.
type EnvProperties struct {
	// KeysToSanitize holds key patterns whose values are masked. A pattern containing any of
	// "*$^+" is compiled as a regular expression; any other pattern matches keys equal to or
	// ending with it. NewSanitizer always adds DefaultKeysToSanitize on top of these.
	KeysToSanitize utils.StringSet `json:"keys-to-sanitize"`
}
// NewEnvProperties creates an EnvProperties with default values.
// KeysToSanitize starts as a copy of DefaultKeysToSanitize, so the two default
// lists cannot drift apart and the returned set can be modified freely.
func NewEnvProperties() *EnvProperties {
	return &EnvProperties{
		KeysToSanitize: DefaultKeysToSanitize.Copy(),
	}
}
// BindEnvProperties creates an EnvProperties with default values and overlays
// the configuration found under EnvPropertiesPrefix. It panics if binding fails.
func BindEnvProperties(ctx *bootstrap.ApplicationContext) EnvProperties {
	props := NewEnvProperties()
	err := ctx.Config().Bind(props, EnvPropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind EnvProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package env
import (
"context"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
"regexp"
"strings"
)
const (
	// regexChars: a key pattern containing any of these characters is compiled as a regular
	// expression. Note that "." is not included, so e.g. `sun.java.command` is matched literally.
	regexChars = "*$^+"
)

var (
	// DefaultKeysToSanitize are always included by NewSanitizer, in addition to configured patterns.
	DefaultKeysToSanitize = utils.NewStringSet(
		`.*password.*`, `.*secret.*`,
		`key`, `.*credentials.*`,
		`vcap_services`, `sun.java.command`,
	)
)

// Sanitizer masks configuration values whose keys match any of its patterns.
type Sanitizer struct {
	// keyMatcher is the OR of one matcher per pattern
	keyMatcher matcher.StringMatcher
}
// NewSanitizer builds a Sanitizer whose key matcher is the OR of one matcher per
// pattern in DefaultKeysToSanitize plus keyPatterns.
//   - A pattern containing any of regexChars is compiled with regexp.MustCompile,
//     which panics on an invalid expression.
//   - Any other pattern matches a key equal to it or ending with it, built via
//     matcher.WithString / matcher.WithSuffix with false as the second argument
//     (NOTE(review): presumably case-insensitive — confirm).
func NewSanitizer(keyPatterns []string) *Sanitizer {
	patterns := DefaultKeysToSanitize.Copy().Add(keyPatterns...)
	var combined matcher.StringMatcher
	for pattern := range patterns {
		var m matcher.StringMatcher
		switch {
		case isRegex(pattern):
			m = matcher.WithRegexPattern(regexp.MustCompile(pattern))
		default:
			m = matcher.WithString(pattern, false).Or(matcher.WithSuffix(pattern, false))
		}
		if combined != nil {
			combined = combined.Or(m)
		} else {
			combined = m
		}
	}
	return &Sanitizer{keyMatcher: combined}
}
// Sanitize returns "********" in place of value when both conditions hold:
//   - value is a string, []string or utils.StringSet
//   - key matches the sanitizer's patterns
//
// Other value types, non-matching keys and matcher errors return value unchanged.
func (s Sanitizer) Sanitize(ctx context.Context, key string, value interface{}) interface{} {
	sanitizable := false
	switch value.(type) {
	case string, []string, utils.StringSet:
		sanitizable = true
	}
	if !sanitizable {
		return value
	}
	matched, err := s.keyMatcher.MatchesWithContext(ctx, key)
	if err == nil && matched {
		return "********"
	}
	return value
}
// isRegex reports whether s contains any of regexChars and should therefore be
// treated as a regular expression rather than a literal key.
func isRegex(s string) bool {
	return strings.IndexAny(s, regexChars) >= 0
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package health
import (
"context"
"math"
)
/*******************************
SimpleStatusAggregator
********************************/
var (
	// DefaultStatusOrders lists statuses by aggregation precedence: earlier
	// entries win, so the aggregate is the "worst" status present. Statuses not
	// listed rank after all listed ones.
	// Previously StatusDown appeared twice (at index 0 and 3). The later index
	// won in the order map, ranking DOWN below UP, and StatusUnknown was
	// missing from the list.
	DefaultStatusOrders = []Status{
		StatusDown, StatusOutOfService, StatusUp, StatusUnknown,
	}
)
// AggregateOptions is a functional option for NewSimpleStatusAggregator.
type AggregateOptions func(opt *AggregateOption)

// AggregateOption configures a SimpleStatusAggregator.
type AggregateOption struct {
	// StatusOrders lists statuses by precedence; earlier entries win aggregation
	StatusOrders []Status
}

// SimpleStatusAggregator implements StatusAggregator
type SimpleStatusAggregator struct {
	// orders maps each status to its precedence index (lower wins)
	orders map[Status]int
}
// NewSimpleStatusAggregator creates a SimpleStatusAggregator that ranks
// statuses by their position in AggregateOption.StatusOrders. The list is
// DefaultStatusOrders unless opts override it; a lower index means higher
// precedence.
// A status listed more than once keeps its first position, so a duplicate
// entry cannot demote it.
func NewSimpleStatusAggregator(opts ...AggregateOptions) *SimpleStatusAggregator {
	opt := AggregateOption{
		StatusOrders: DefaultStatusOrders,
	}
	for _, f := range opts {
		f(&opt)
	}
	orders := make(map[Status]int, len(opt.StatusOrders))
	for i, s := range opt.StatusOrders {
		if _, exists := orders[s]; !exists {
			orders[s] = i
		}
	}
	return &SimpleStatusAggregator{
		orders: orders,
	}
}
// Aggregate returns the highest-precedence status among statuses according to
// the configured order; on ties, the earliest one wins. With no statuses it
// returns StatusUnknown.
func (a SimpleStatusAggregator) Aggregate(_ context.Context, statuses ...Status) Status {
	if len(statuses) == 0 {
		return StatusUnknown
	}
	best := statuses[0]
	for _, s := range statuses[1:] {
		if a.compare(s, best) < 0 {
			best = s
		}
	}
	return best
}
// compare orders s1 relative to s2 by configured precedence. It returns -1 if
// s1 ranks before s2, 1 if after, and 0 if equal. Statuses missing from the
// order map rank last and compare equal to each other.
func (a SimpleStatusAggregator) compare(s1, s2 Status) int {
	// math.MaxInt rather than math.MaxInt64: assigning MaxInt64 to an int does
	// not compile on 32-bit platforms
	o1, ok := a.orders[s1]
	if !ok {
		o1 = math.MaxInt
	}
	o2, ok := a.orders[s2]
	if !ok {
		o2 = math.MaxInt
	}
	switch {
	case o1 < o2:
		return -1
	case o1 > o2:
		return 1
	default:
		return 0
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package health
import (
"context"
"net/http"
)
/*******************************
StaticStatusCodeMapper
********************************/
// DefaultStaticStatusCodeMapper maps StatusUp to 200, StatusDown and
// StatusOutOfService to 503, and StatusUnknown to 500.
var DefaultStaticStatusCodeMapper = StaticStatusCodeMapper{
	StatusUp:           http.StatusOK,
	StatusDown:         http.StatusServiceUnavailable,
	StatusOutOfService: http.StatusServiceUnavailable,
	StatusUnknown:      http.StatusInternalServerError,
}

// StaticStatusCodeMapper maps health statuses to HTTP status codes using a fixed table.
type StaticStatusCodeMapper map[Status]int
func (m StaticStatusCodeMapper) StatusCode(_ context.Context, status Status) int {
if sc, ok := m[status]; ok {
return sc
}
return http.StatusServiceUnavailable
}
/*******************************
SimpleHealth
********************************/
// SimpleHealth implements Health with only a status and an optional description.
type SimpleHealth struct {
	Stat Status `json:"status"`
	Desc string `json:"description,omitempty"`
}

// Status returns the reported health status.
func (s SimpleHealth) Status() Status { return s.Stat }

// Description returns the description, which may be empty.
func (s SimpleHealth) Description() string { return s.Desc }
/*******************************
Composite
********************************/
// CompositeHealth implements Health. In addition to its own status, it carries the health of named
// sub-components.
type CompositeHealth struct {
	SimpleHealth
	Components map[string]Health `json:"components,omitempty"`
}

// NewCompositeHealth returns a *CompositeHealth with the given status, description and components.
// The components map is stored as-is, not copied.
func NewCompositeHealth(status Status, description string, components map[string]Health) *CompositeHealth {
	h := &CompositeHealth{Components: components}
	h.Stat = status
	h.Desc = description
	return h
}
/*******************************
DetailedHealth
********************************/
// DetailedHealth implements Health. In addition to its status, it carries free-form details.
type DetailedHealth struct {
	SimpleHealth
	Details map[string]interface{} `json:"details,omitempty"`
}

// NewDetailedHealth returns a *DetailedHealth with the given status, description and details.
// The details map is stored as-is, not copied.
func NewDetailedHealth(status Status, description string, details map[string]interface{}) *DetailedHealth {
	h := &DetailedHealth{Details: details}
	h.Stat = status
	h.Desc = description
	return h
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package health
import (
"context"
"strings"
)
// Health statuses. The zero value is StatusUnknown.
const (
	StatusUnknown Status = iota
	StatusUp
	StatusOutOfService
	StatusDown
)

// Status is the health state reported by an Indicator. It is marshalled to and from text as
// "UP", "DOWN", "OUT_OF_SERVICE" or "UNKNOWN".
type Status int

// String implements fmt.Stringer. Values outside the declared constants render as "UNKNOWN".
func (s Status) String() string {
	switch s {
	case StatusDown:
		return "DOWN"
	case StatusOutOfService:
		return "OUT_OF_SERVICE"
	case StatusUp:
		return "UP"
	}
	return "UNKNOWN"
}

// MarshalText implements encoding.TextMarshaler using String.
func (s Status) MarshalText() ([]byte, error) {
	text := s.String()
	return []byte(text), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. Matching is case-insensitive.
// Unrecognized text yields StatusUnknown. It never returns an error.
func (s *Status) UnmarshalText(data []byte) error {
	parsed := StatusUnknown
	switch strings.ToUpper(string(data)) {
	case "UP":
		parsed = StatusUp
	case "DOWN":
		parsed = StatusDown
	case "OUT_OF_SERVICE":
		parsed = StatusOutOfService
	}
	*s = parsed
	return nil
}
const (
	// ShowModeNever Never show the item in the response.
	ShowModeNever ShowMode = iota
	// ShowModeAuthorized Show the item in the response when accessed by an authorized user.
	ShowModeAuthorized
	// ShowModeAlways Always show the item in the response.
	ShowModeAlways
	// ShowModeCustom Shows the item in response with a customized rule.
	ShowModeCustom
)

// ShowMode is options for showing items in responses from the HealthEndpoint web extensions.
// It is marshalled to and from text as "never", "authorized", "always" or "custom".
type ShowMode int

// String implements fmt.Stringer. Values outside the declared constants render as "never".
func (m ShowMode) String() string {
	switch m {
	case ShowModeAuthorized:
		return "authorized"
	case ShowModeAlways:
		return "always"
	case ShowModeCustom:
		return "custom"
	default:
		return "never"
	}
}

// MarshalText implements encoding.TextMarshaler using String.
func (m ShowMode) MarshalText() ([]byte, error) {
	return []byte(m.String()), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. Matching is case-insensitive.
// "authorized", "when_authorized", "whenAuthorized" and "when-authorized" all select
// ShowModeAuthorized. Unrecognized values select ShowModeNever. It never returns an error.
func (m *ShowMode) UnmarshalText(data []byte) error {
	// value is lower-cased, so every case label must be lower-case too: the camel-case spelling
	// "whenAuthorized" arrives here as "whenauthorized".
	value := strings.ToLower(string(data))
	switch value {
	case "authorized", "when_authorized", "whenauthorized", "when-authorized":
		*m = ShowModeAuthorized
	case "always":
		*m = ShowModeAlways
	case "custom":
		*m = ShowModeCustom
	default:
		*m = ShowModeNever
	}
	return nil
}
// Registrar configures the system health indicator and the health endpoint.
type Registrar interface {
	// Register configures SystemHealthRegistrar and HealthEndpoint.
	// Supported input parameters are:
	// - Indicator
	// - StatusAggregator
	// - DetailsDisclosureControl
	// - ComponentsDisclosureControl
	// - DisclosureControl
	// SystemHealthRegistrar returns an error for unsupported item types.
	Register(items ...interface{}) error
	// MustRegister is the same as Register, but panics if there is an error.
	MustRegister(items ...interface{})
}

// StatusAggregator reduces several statuses into one overall Status.
// For example, CompositeIndicator uses it to combine its delegates' statuses.
type StatusAggregator interface {
	Aggregate(context.Context, ...Status) Status
}

// StatusCodeMapper maps a Status to an HTTP status code.
// The health endpoint uses it for the status code of its response.
type StatusCodeMapper interface {
	StatusCode(context.Context, Status) int
}

// Health is the result reported by an Indicator: a Status plus a description.
type Health interface {
	Status() Status
	Description() string
}

// Options tells an Indicator which optional parts of its Health may be included.
// CompositeIndicator additionally sanitizes delegate results that include parts these options
// do not allow.
type Options struct {
	ShowDetails    bool
	ShowComponents bool
}

// Indicator reports the health of one part of the system.
// Name identifies the indicator; CompositeIndicator uses it as the key in its components map.
type Indicator interface {
	Name() string
	Health(context.Context, Options) Health
}

// DetailsDisclosureControl decides, per request context, whether health details are shown.
type DetailsDisclosureControl interface {
	ShouldShowDetails(ctx context.Context) bool
}

// ComponentsDisclosureControl decides, per request context, whether health components are shown.
type ComponentsDisclosureControl interface {
	ShouldShowComponents(ctx context.Context) bool
}

// DisclosureControl combines DetailsDisclosureControl and ComponentsDisclosureControl.
type DisclosureControl interface {
	DetailsDisclosureControl
	ComponentsDisclosureControl
}
// DisclosureControlFunc adapts a single function into a DisclosureControl. The same decision is
// used for both details and components.
// This type can be registered via Registrar.Register
type DisclosureControlFunc func(ctx context.Context) bool

// ShouldShowDetails implements DetailsDisclosureControl by calling fn.
func (fn DisclosureControlFunc) ShouldShowDetails(ctx context.Context) bool {
	return fn(ctx)
}

// ShouldShowComponents implements ComponentsDisclosureControl. It makes the same decision as
// ShouldShowDetails.
func (fn DisclosureControlFunc) ShouldShowComponents(ctx context.Context) bool {
	return fn.ShouldShowDetails(ctx)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package healthep
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// DefaultDisclosureControl implements health.DetailsDisclosureControl and health.ComponentsDisclosureControl
// based on the configured health.ShowMode values. ShowModeCustom defers to the corresponding delegate.
type DefaultDisclosureControl struct {
	showDetails         health.ShowMode
	showComponents      health.ShowMode
	permissions         utils.StringSet
	detailsCtrlDelegate health.DetailsDisclosureControl
	compsCtrlDelegate   health.ComponentsDisclosureControl
}

// newDefaultDisclosureControl builds a DefaultDisclosureControl from the health properties.
//
// When props.ShowComponents is nil, components follow props.ShowDetails. It returns an error when
// a mode is ShowModeCustom but the matching delegate is nil, because nothing could then make the
// custom decision.
func newDefaultDisclosureControl(props *health.HealthProperties,
	detailsDelegate health.DetailsDisclosureControl,
	compsDelegate health.ComponentsDisclosureControl) (*DefaultDisclosureControl, error) {
	showComponents := props.ShowDetails
	if props.ShowComponents != nil {
		showComponents = *props.ShowComponents
	}
	// Each error message names the control type that is actually missing.
	switch {
	case props.ShowDetails == health.ShowModeCustom && detailsDelegate == nil:
		return nil, errors.New(`health details control is set to custom but there is no health.DetailsDisclosureControl configured`)
	case showComponents == health.ShowModeCustom && compsDelegate == nil:
		return nil, errors.New(`health components control is set to custom but there is no health.ComponentsDisclosureControl configured`)
	}
	return &DefaultDisclosureControl{
		showDetails:         props.ShowDetails,
		showComponents:      showComponents,
		permissions:         utils.NewStringSet(props.Permissions...),
		detailsCtrlDelegate: detailsDelegate,
		compsCtrlDelegate:   compsDelegate,
	}, nil
}
// ShouldShowDetails implements health.DetailsDisclosureControl according to the configured details
// ShowMode. Any mode other than never, always or authorized is delegated to the custom details control.
func (c *DefaultDisclosureControl) ShouldShowDetails(ctx context.Context) bool {
	mode := c.showDetails
	if mode == health.ShowModeNever {
		return false
	}
	if mode == health.ShowModeAlways {
		return true
	}
	if mode == health.ShowModeAuthorized {
		return c.isAuthorized(ctx)
	}
	return c.detailsCtrlDelegate.ShouldShowDetails(ctx)
}
// ShouldShowComponents implements health.ComponentsDisclosureControl according to the configured
// components ShowMode. Any mode other than never, always or authorized is delegated to the custom
// components control.
func (c *DefaultDisclosureControl) ShouldShowComponents(ctx context.Context) bool {
	mode := c.showComponents
	if mode == health.ShowModeNever {
		return false
	}
	if mode == health.ShowModeAlways {
		return true
	}
	if mode == health.ShowModeAuthorized {
		return c.isAuthorized(ctx)
	}
	return c.compsCtrlDelegate.ShouldShowComponents(ctx)
}
// isAuthorized reports whether the current security context is authenticated, has a non-nil
// permission set, and holds every permission configured in the health properties. With no
// permissions configured, any such authenticated user is authorized.
func (c *DefaultDisclosureControl) isAuthorized(ctx context.Context) bool {
	auth := security.Get(ctx)
	if auth.State() < security.StateAuthenticated {
		return false
	}
	// Fetch the permission set once instead of once per required permission.
	granted := auth.Permissions()
	if granted == nil {
		return false
	}
	for p := range c.permissions {
		if _, ok := granted[p]; !ok {
			return false
		}
	}
	return true
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package healthep
import (
"context"
"encoding/json"
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
)
const (
	ID              = "health"
	EnableByDefault = true
)

// Input is the (empty) request model of the health read operation.
type Input struct{}

// Output is the response of the health read operation. It serializes as the embedded health.Health
// and carries the HTTP status code to respond with.
type Output struct {
	health.Health
	sc int
}

// CompositeHealthV2 carries the same data as health.CompositeHealth, but serializes components
// under "details". It is used for the actuator.ContentTypeSpringBootV2 format.
type CompositeHealthV2 struct {
	health.SimpleHealth
	Components map[string]health.Health `json:"details,omitempty"`
}

// StatusCode implements http.StatusCoder and returns the mapped HTTP status code.
func (o Output) StatusCode() int {
	code := o.sc
	return code
}

// Body implements web.BodyContainer and returns the wrapped health as the response body.
func (o Output) Body() interface{} {
	body := o.Health
	return body
}

// MarshalJSON implements json.Marshaler by serializing only the wrapped health.
func (o Output) MarshalJSON() ([]byte, error) {
	payload := o.Health
	return json.Marshal(payload)
}
// EndpointOptions is a functional option that customizes EndpointOption for newEndpoint.
type EndpointOptions func(opt *EndpointOption)

// EndpointOption holds the inputs newEndpoint uses to build a HealthEndpoint.
type EndpointOption struct {
	// Contributor is the indicator whose Health is served (the system composite indicator in register).
	Contributor health.Indicator
	// StatusCodeMapper maps statuses to HTTP codes; when nil, newEndpoint builds one from defaults plus properties.
	StatusCodeMapper health.StatusCodeMapper
	MgtProperties    actuator.ManagementProperties
	Properties       health.HealthProperties
	// DetailsControl is the delegate used when details are set to the "custom" show mode.
	DetailsControl health.DetailsDisclosureControl
	// ComponentsControl is the delegate used when components are set to the "custom" show mode.
	ComponentsControl health.ComponentsDisclosureControl
}

// HealthEndpoint implements actuator.Endpoint, actuator.WebEndpoint
type HealthEndpoint struct {
	actuator.WebEndpointBase
	contributor       health.Indicator
	scMapper          health.StatusCodeMapper
	detailsControl    health.DetailsDisclosureControl
	componentsControl health.ComponentsDisclosureControl
}
// newEndpoint creates the health endpoint from the given options.
//
// When no StatusCodeMapper is provided, it uses a copy of health.DefaultStaticStatusCodeMapper
// overlaid with the configured status "http-mapping" entries. It returns an error when the
// disclosure controls cannot be built (see newDefaultDisclosureControl).
func newEndpoint(opts ...EndpointOptions) (*HealthEndpoint, error) {
	opt := EndpointOption{}
	for _, f := range opts {
		f(&opt)
	}
	if opt.StatusCodeMapper == nil {
		// StaticStatusCodeMapper is a map, i.e. a reference type. Writing the overrides straight into
		// health.DefaultStaticStatusCodeMapper would mutate the shared default for every other user.
		// Build a private copy instead.
		overrides := opt.Properties.Status.ScMapping
		scMapper := make(health.StaticStatusCodeMapper, len(health.DefaultStaticStatusCodeMapper)+len(overrides))
		for k, v := range health.DefaultStaticStatusCodeMapper {
			scMapper[k] = v
		}
		for k, v := range overrides {
			scMapper[k] = v
		}
		opt.StatusCodeMapper = scMapper
	}
	disclosureCtrl, e := newDefaultDisclosureControl(&opt.Properties, opt.DetailsControl, opt.ComponentsControl)
	if e != nil {
		return nil, e
	}
	ep := HealthEndpoint{
		contributor:       opt.Contributor,
		scMapper:          opt.StatusCodeMapper,
		detailsControl:    disclosureCtrl,
		componentsControl: disclosureCtrl,
	}
	properties := opt.MgtProperties
	ep.WebEndpointBase = actuator.MakeWebEndpointBase(func(opt *actuator.EndpointOption) {
		opt.Id = ID
		opt.Ops = []actuator.Operation{
			actuator.NewReadOperation(ep.Read),
		}
		opt.Properties = &properties.Endpoints
		opt.EnabledByDefault = EnableByDefault
	})
	return &ep, nil
}
// Read evaluates the contributor's health, disclosing details and components only as the disclosure
// controls allow for this request. For the Spring Boot V2 content type, it converts the result into
// the V2 shape. The HTTP status code comes from the status code mapper.
// Read never returns error
func (ep *HealthEndpoint) Read(ctx context.Context, _ *Input) (*Output, error) {
	showDetails := ep.detailsControl.ShouldShowDetails(ctx)
	showComponents := ep.componentsControl.ShouldShowComponents(ctx)
	h := ep.contributor.Health(ctx, health.Options{
		ShowDetails:    showDetails,
		ShowComponents: showComponents,
	})
	if ep.WebEndpointBase.NegotiateFormat(ctx) == actuator.ContentTypeSpringBootV2 {
		h = ep.toSpringBootV2(h)
	}
	// Note: *SystemHealthInitializer respects options (as all CompositeIndicator do),
	// so the result does not need to be sanitized here.
	out := &Output{
		Health: h,
		sc:     ep.scMapper.StatusCode(ctx, h.Status()),
	}
	return out, nil
}
// toSpringBootV2 recursively converts a CompositeHealth (value or pointer) into CompositeHealthV2,
// so that components are serialized under "details". Any other Health, including a nil
// *health.CompositeHealth, is returned unchanged.
func (ep *HealthEndpoint) toSpringBootV2(h health.Health) health.Health {
	var composite *health.CompositeHealth
	switch v := h.(type) {
	case health.CompositeHealth:
		composite = &v
	case *health.CompositeHealth:
		if v == nil {
			// Nothing to convert, and dereferencing v below would panic.
			return h
		}
		composite = v
	default:
		return h
	}
	ret := CompositeHealthV2{
		SimpleHealth: composite.SimpleHealth,
		Components:   make(map[string]health.Health, len(composite.Components)),
	}
	// recursively convert components
	for k, v := range composite.Components {
		ret.Components[k] = ep.toSpringBootV2(v)
	}
	return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
// Package healthep contains the implementation of the health endpoint. It is kept in a separate
// package to avoid a cyclic package dependency.
//
// Implementations in this package cannot be moved to package "actuator/health"; otherwise they would create
// a cyclic package dependency as follows:
// actuator/health -> actuator -> security -> tenancy -> redis -> actuator/health
//
// Therefore, any implementation that involves the packages mentioned above should be placed here.
package healthep
import (
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"go.uber.org/fx"
)
// Module registers the health endpoint with the actuator during application start (see register).
var Module = &bootstrap.Module{
	Name:       "actuator-health-ep",
	Precedence: actuator.MinActuatorPrecedence,
	Options: []fx.Option{
		fx.Invoke(register),
	},
}

// Register enables the health endpoint. It enables the underlying health module via health.Use
// and registers this Module with bootstrap.
func Register() {
	health.Use()
	bootstrap.Register(Module)
}

// regDI lists the fx dependencies of register. The actuator registrar and management properties
// are optional; without an actuator registrar, no endpoint is registered.
type regDI struct {
	fx.In
	Properties      health.HealthProperties
	HealthRegistrar health.Registrar
	Registrar       *actuator.Registrar           `optional:"true"`
	MgtProperties   actuator.ManagementProperties `optional:"true"`
}
// register builds the health endpoint from the system health registrar and registers it with the
// actuator.
//
// It does nothing when no *actuator.Registrar is available. It panics when di.HealthRegistrar is
// not a *health.SystemHealthRegistrar, or when newEndpoint fails (e.g. a "custom" show mode without
// a matching disclosure control).
func register(di regDI) {
	// Note: without an actuator.Registrar there is nothing to register the endpoint with.
	if di.Registrar == nil {
		return
	}
	sysReg := di.HealthRegistrar.(*health.SystemHealthRegistrar)
	configure := func(opt *EndpointOption) {
		opt.MgtProperties = di.MgtProperties
		opt.Properties = di.Properties
		opt.Contributor = sysReg.Indicator
		opt.DetailsControl = sysReg.DetailsDisclosure
		opt.ComponentsControl = sysReg.ComponentsDisclosure
	}
	ep, err := newEndpoint(configure)
	if err != nil {
		panic(err)
	}
	di.Registrar.MustRegister(ep)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package health
import "context"
/*******************************
CompositeIndicator
********************************/
// IndicatorOptions is a functional option that customizes IndicatorOption for NewCompositeIndicator.
type IndicatorOptions func(opt *IndicatorOption)

// IndicatorOption holds the settings consumed by NewCompositeIndicator.
type IndicatorOption struct {
	// Name is returned by CompositeIndicator.Name.
	Name string
	// Contributors are the delegate indicators whose health is combined. Defaults to an empty list.
	Contributors []Indicator
	// Aggregator combines the delegates' statuses. Defaults to NewSimpleStatusAggregator().
	Aggregator StatusAggregator
}
// CompositeIndicator implements Indicator by combining the health of its delegate indicators.
// SystemHealthRegistrar uses one as the root "system" indicator.
type CompositeIndicator struct {
	name       string
	delegates  []Indicator
	aggregator StatusAggregator
}

// NewCompositeIndicator creates a CompositeIndicator. By default it has no contributors and uses
// NewSimpleStatusAggregator(); opts may override any IndicatorOption field. The Contributors slice
// is used as-is, not copied.
func NewCompositeIndicator(opts ...IndicatorOptions) *CompositeIndicator {
	cfg := &IndicatorOption{
		Contributors: []Indicator{},
		Aggregator:   NewSimpleStatusAggregator(),
	}
	for _, configure := range opts {
		configure(cfg)
	}
	return &CompositeIndicator{
		name:       cfg.Name,
		delegates:  cfg.Contributors,
		aggregator: cfg.Aggregator,
	}
}
// Add appends contributors to the delegates consulted by Health.
func (c *CompositeIndicator) Add(contributors ...Indicator) {
	merged := append(c.delegates, contributors...)
	c.delegates = merged
}

// Name implements Indicator and returns the configured name.
func (c *CompositeIndicator) Name() string {
	name := c.name
	return name
}
// Health implements Indicator. It queries every delegate with the same options and sanitizes each
// result against those options. The overall status is the aggregator's reduction of the delegates'
// statuses. Delegate results are included as components, keyed by delegate name, only when
// options.ShowComponents is set. The description is empty.
func (c *CompositeIndicator) Health(ctx context.Context, options Options) Health {
	comps := make(map[string]Health)
	collected := make([]Status, 0, len(c.delegates))
	for _, delegate := range c.delegates {
		// Delegates should respect options, but sanitize anyway so that nothing leaks.
		result := trySanitize(delegate.Health(ctx, options), options, false)
		if options.ShowComponents {
			comps[delegate.Name()] = result
		}
		collected = append(collected, result.Status())
	}
	overall := c.aggregator.Aggregate(ctx, collected...)
	return NewCompositeHealth(overall, "", comps)
}
/*******************************
helpers
********************************/
// trySanitize strips components when opts.ShowComponents is false and details when
// opts.ShowDetails is false. With deep set, details are stripped recursively through composite
// components. Pointer results may be modified in place.
func trySanitize(health Health, opts Options, deep bool) Health {
	result := health
	if !opts.ShowComponents {
		result = sanitizeComponents(result)
	}
	if !opts.ShowDetails {
		result = sanitizeDetails(result, deep)
	}
	return result
}
// sanitizeComponents removes the components of a CompositeHealth. A *CompositeHealth is cleared in
// place; a CompositeHealth value is replaced by a new *CompositeHealth without components. Any other
// Health is returned unchanged.
func sanitizeComponents(health Health) Health {
	switch h := health.(type) {
	case *CompositeHealth:
		h.Components = map[string]Health{}
	case CompositeHealth:
		return NewCompositeHealth(h.Status(), h.Description(), map[string]Health{})
	}
	return health
}
// sanitizeDetails removes the details of a DetailedHealth. A *DetailedHealth is cleared in place; a
// DetailedHealth value is replaced by a new *DetailedHealth. With deep set, it also recurses into
// the components of a CompositeHealth (in place for pointers, into a new *CompositeHealth for values).
func sanitizeDetails(health Health, deep bool) Health {
	switch h := health.(type) {
	case *DetailedHealth:
		h.Details = nil
	case DetailedHealth:
		health = NewDetailedHealth(h.Status(), h.Description(), nil)
	}
	if !deep {
		return health
	}
	switch h := health.(type) {
	case *CompositeHealth:
		for name, comp := range h.Components {
			h.Components[name] = sanitizeDetails(comp, true)
		}
	case CompositeHealth:
		cleaned := make(map[string]Health, len(h.Components))
		for name, comp := range h.Components {
			cleaned[name] = sanitizeDetails(comp, true)
		}
		return NewCompositeHealth(h.Status(), h.Description(), cleaned)
	}
	return health
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package health
import (
"context"
"github.com/cisco-open/go-lanai/pkg/web"
)
// PingIndicator is a minimal Indicator named "ping".
type PingIndicator struct {
}

// Name implements Indicator.
func (b PingIndicator) Name() string {
	return "ping"
}

// Health reports StatusUp when ctx carries a gin context and StatusUnknown otherwise. A gin context
// means the check was invoked through the web endpoint, which shows the web framework still works.
// The result never carries details, and options are not consulted.
func (b PingIndicator) Health(ctx context.Context, options Options) Health {
	status := StatusUnknown
	if web.GinContext(ctx) != nil {
		status = StatusUp
	}
	return NewDetailedHealth(status, "ping", nil)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package health
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"go.uber.org/fx"
)
// Module provides HealthProperties and the *SystemHealthRegistrar. Through provideInterfaces, it
// also provides the Registrar and Indicator interfaces backed by that registrar.
var Module = &bootstrap.Module{
	Name:       "actuator-health",
	Precedence: bootstrap.ActuatorPrecedence,
	Options: []fx.Option{
		fx.Provide(
			BindHealthProperties,
			NewSystemHealthRegistrar,
			provideInterfaces,
		),
	},
}

// Use registers Module with bootstrap.
func Use() {
	bootstrap.Register(Module)
}

// provideInterfaces exposes reg as Registrar and its root composite indicator as Indicator.
func provideInterfaces(reg *SystemHealthRegistrar) (Registrar, Indicator) {
	return reg, reg.Indicator
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package health
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"strings"
)
const (
HealthPropertiesPrefix = "management.endpoint.health"
)
// HealthProperties configures the health endpoint. It is bound from HealthPropertiesPrefix.
type HealthProperties struct {
	Status StatusProperties `json:"status"`
	// When to show components. If not specified, the 'show-details' setting is used.
	ShowComponents *ShowMode `json:"show-components"`
	// When to show full health details.
	ShowDetails ShowMode `json:"show-details"`
	// Permissions used to determine whether a user is authorized to be shown details.
	// When empty, all authenticated users are authorized.
	Permissions utils.CommaSeparatedSlice `json:"permissions"`
}
// StatusOrders lists health statuses by precedence; the first entry wins when aggregating.
// It unmarshals from a comma-separated string such as "DOWN, OUT_OF_SERVICE, UP, UNKNOWN".
type StatusOrders []Status

// UnmarshalText implements encoding.TextUnmarshaler.
//
// Entries are trimmed and parsed with Status.UnmarshalText, so unrecognized names become
// StatusUnknown. Empty entries (from "", a trailing comma or ",,") are skipped. Previously they were
// parsed as StatusUnknown, which silently inserted UNKNOWN into the order. An empty value now yields
// an empty order, and the aggregator then falls back to its default order.
func (o *StatusOrders) UnmarshalText(data []byte) error {
	split := strings.Split(string(data), ",")
	result := make([]Status, 0, len(split))
	for _, s := range split {
		s = strings.TrimSpace(s)
		if s == "" {
			continue
		}
		status := StatusUnknown
		if e := status.UnmarshalText([]byte(s)); e != nil {
			return e
		}
		result = append(result, status)
	}
	*o = result
	return nil
}
// StatusProperties configures how statuses are aggregated and mapped to HTTP status codes.
type StatusProperties struct {
	// Comma-separated list of health statuses in order of severity.
	// Statuses earlier in the list take precedence when aggregating.
	Orders StatusOrders `json:"order"`
	// Mapping of health statuses to HTTP status codes. By default, registered health
	// statuses map to sensible defaults (for example, UP maps to 200).
	// Entries here override those defaults.
	ScMapping map[Status]int `json:"http-mapping"`
}
// NewHealthProperties creates a HealthProperties with default values: status order
// DOWN, OUT_OF_SERVICE, UP, UNKNOWN; an empty status-code mapping; and no required permissions.
func NewHealthProperties() *HealthProperties {
	props := &HealthProperties{}
	props.Status.Orders = StatusOrders{StatusDown, StatusOutOfService, StatusUp, StatusUnknown}
	props.Status.ScMapping = map[Status]int{}
	props.Permissions = []string{}
	return props
}

// BindHealthProperties creates a HealthProperties with defaults and binds the configuration under
// HealthPropertiesPrefix onto it. It panics when binding fails.
func BindHealthProperties(ctx *bootstrap.ApplicationContext) HealthProperties {
	props := NewHealthProperties()
	err := ctx.Config().Bind(props, HealthPropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind HealthProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package health
import (
"fmt"
"go.uber.org/fx"
)
// SystemHealthRegistrar implements Registrar. It owns the root "system" CompositeIndicator and
// holds the disclosure controls registered for the health endpoint; these stay nil until registered.
type SystemHealthRegistrar struct {
	Indicator            *CompositeIndicator
	DetailsDisclosure    DetailsDisclosureControl
	ComponentsDisclosure ComponentsDisclosureControl
}

// regDI lists the fx dependencies of NewSystemHealthRegistrar.
type regDI struct {
	fx.In
	Properties HealthProperties
}
// NewSystemHealthRegistrar creates the registrar with a root "system" CompositeIndicator. The root
// initially contains only PingIndicator. Its aggregator uses the configured status order when one is
// set, and the aggregator's default order otherwise.
func NewSystemHealthRegistrar(di regDI) *SystemHealthRegistrar {
	configuredOrders := di.Properties.Status.Orders
	aggregator := NewSimpleStatusAggregator(func(opt *AggregateOption) {
		if len(configuredOrders) != 0 {
			opt.StatusOrders = configuredOrders
		}
	})
	root := &CompositeIndicator{
		name:       "system",
		delegates:  []Indicator{PingIndicator{}},
		aggregator: aggregator,
	}
	return &SystemHealthRegistrar{Indicator: root}
}
// Register configures the SystemHealthRegistrar with each item, in order.
// Supported input parameters are:
// - Indicator
// - StatusAggregator
// - DetailsDisclosureControl
// - ComponentsDisclosureControl
// - DisclosureControl
// - []interface{} containing any of the above
// It stops at and returns the first error; items before it remain registered.
func (i *SystemHealthRegistrar) Register(items ...interface{}) error {
	for idx := range items {
		if err := i.register(items[idx]); err != nil {
			return err
		}
	}
	return nil
}

// MustRegister is like Register, but panics with the error instead of returning it.
func (i *SystemHealthRegistrar) MustRegister(items ...interface{}) {
	err := i.Register(items...)
	if err == nil {
		return
	}
	panic(err)
}
// register applies a single item and returns an error for unsupported item types.
//
// Case order matters: a value implementing several supported interfaces is handled only by the first
// matching case. For example, an Indicator that is also a StatusAggregator is only added as an
// indicator. A DisclosureControl (such as DisclosureControlFunc) is caught before the narrower cases
// and sets both the details and components controls. A []interface{} is registered recursively.
func (i *SystemHealthRegistrar) register(item interface{}) error {
	switch v := item.(type) {
	case []interface{}:
		return i.Register(v...)
	case Indicator:
		i.Indicator.Add(v)
	case StatusAggregator:
		// replaces the aggregator of the root indicator
		i.Indicator.aggregator = v
	case DisclosureControl:
		i.DetailsDisclosure = v
		i.ComponentsDisclosure = v
	case DetailsDisclosureControl:
		i.DetailsDisclosure = v
	case ComponentsDisclosureControl:
		i.ComponentsDisclosure = v
	default:
		return fmt.Errorf("unsupported item %T", item)
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package info
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/appconfig"
)
const (
	ID              = "info"
	EnableByDefault = true
	// infoPropertiesPrefix is the configuration prefix whose properties are exposed by the endpoint.
	infoPropertiesPrefix = "info"
)

// Input is the request model of the info read operation. Name optionally selects a single entry;
// it is presumably bound from the "name" path parameter via the `uri` tag — confirm with the
// web binding.
type Input struct {
	Name string `uri:"name"`
}

// Info is the info endpoint payload: the properties bound from the "info" prefix, plus an optional
// "build-info" entry.
type Info map[string]interface{}

// InfoEndpoint implements actuator.Endpoint, actuator.WebEndpoint
//
//goland:noinspection GoNameStartsWithPackageName
type InfoEndpoint struct {
	actuator.WebEndpointBase
	appConfig appconfig.ConfigAccessor
}
// newEndpoint creates the "info" endpoint. It exposes a single read operation backed by the
// application configuration. It panics if the application config is not an
// appconfig.ConfigAccessor (unchecked type assertion).
func newEndpoint(di regDI) *InfoEndpoint {
	ep := &InfoEndpoint{
		appConfig: di.AppContext.Config().(appconfig.ConfigAccessor),
	}
	ep.WebEndpointBase = actuator.MakeWebEndpointBase(func(opt *actuator.EndpointOption) {
		opt.Id = ID
		opt.Properties = &di.MgtProperties.Endpoints
		opt.EnabledByDefault = EnableByDefault
		opt.Ops = []actuator.Operation{
			actuator.NewReadOperation(ep.Read),
		}
	})
	return ep
}
// Read returns the properties bound from the "info" configuration prefix. When build info can be
// bound, it also includes a "build-info" entry. When input.Name is set, only that entry is returned,
// or an empty Info if it is absent. It returns an error only when binding the "info" prefix fails.
func (ep *InfoEndpoint) Read(ctx context.Context, input *Input) (interface{}, error) {
	info := Info{}
	if err := ep.appConfig.Bind(&info, infoPropertiesPrefix); err != nil {
		return Info{}, err
	}
	// Build info is optional: if binding fails, it is simply left out.
	buildInfo := map[string]interface{}{}
	if ep.appConfig.Bind(&buildInfo, appconfig.PropertyKeyBuildInfo) == nil {
		info["build-info"] = buildInfo
	}
	logger.WithContext(ctx).Debugf("info %v", info)
	if input.Name == "" {
		return info, nil
	}
	entry, found := info[input.Name]
	if !found {
		return Info{}, nil
	}
	return entry, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package info
import (
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
)
// logger is the package logger for the info endpoint.
var logger = log.New("ACTR.Info")

// Module registers the info endpoint with the actuator during application start (see register).
var Module = &bootstrap.Module{
	Name:       "actuator-info",
	Precedence: actuator.MinActuatorPrecedence,
	Options: []fx.Option{
		fx.Invoke(register),
	},
}

// Register registers Module with bootstrap.
func Register() {
	bootstrap.Register(Module)
}

// regDI lists the fx dependencies of register. All of them are required.
type regDI struct {
	fx.In
	Registrar     *actuator.Registrar
	MgtProperties actuator.ManagementProperties
	AppContext    *bootstrap.ApplicationContext
}
// register creates the info endpoint and registers it with the actuator. It panics on failure.
func register(di regDI) {
	di.Registrar.MustRegister(newEndpoint(di))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuator
import (
"embed"
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/actuator/alive"
"github.com/cisco-open/go-lanai/pkg/actuator/apilist"
"github.com/cisco-open/go-lanai/pkg/actuator/env"
health "github.com/cisco-open/go-lanai/pkg/actuator/health/endpoint"
"github.com/cisco-open/go-lanai/pkg/actuator/info"
"github.com/cisco-open/go-lanai/pkg/actuator/loggers"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"go.uber.org/fx"
)
// defaultConfigFS embeds defaults-actuator.yml. Module registers it as embedded default configuration.
//
//go:embed defaults-actuator.yml
var defaultConfigFS embed.FS

// Module contributes the actuator's embedded default configuration.
var Module = &bootstrap.Module{
	Name:       "actuate-config",
	Precedence: actuator.MinActuatorPrecedence,
	Options: []fx.Option{
		appconfig.FxEmbeddedDefaults(defaultConfigFS),
	},
}

// Use enables the actuator, its default configuration, and the built-in endpoints:
// info, health, env, alive, apilist and loggers.
func Use() {
	bootstrap.Register(actuator.Module)
	bootstrap.Register(Module)
	info.Register()
	health.Register()
	env.Register()
	alive.Register()
	apilist.Register()
	loggers.Register()
}
/**************************
Initialize
***************************/
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package loggers
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/web"
"net/http"
"strings"
)
const (
	ID              = "loggers"
	EnableByDefault = true
)

var (
	// allLevels is reported in ReadOutput.Levels as the list of available logging levels.
	allLevels = []log.LoggingLevel{
		log.LevelOff, log.LevelDebug, log.LevelInfo, log.LevelWarn, log.LevelError,
	}
)

// ReadInput is the request model of ReadByName; Name comes from the "name" URI parameter.
type ReadInput struct {
	Name string `uri:"name"`
}

// WriteInput is the request model of Write. Prefix comes from the required "name" URI parameter,
// and ConfiguredLevel comes from the "configuredLevel" JSON field.
type WriteInput struct {
	Prefix          string            `uri:"name" binding:"required"`
	ConfiguredLevel *log.LoggingLevel `json:"configuredLevel"`
}

// ReadOutput is the response of ReadAll: the available levels plus each logger's levels, keyed by
// logger name.
type ReadOutput struct {
	Levels  []log.LoggingLevel     `json:"levels"`
	Loggers map[string]LoggerLevel `json:"loggers"`
}

// LoggerLevel holds a logger's effective and configured levels; nil values are omitted from JSON.
type LoggerLevel struct {
	EffectiveLevel  *log.LoggingLevel `json:"effectiveLevel,omitempty"`
	ConfiguredLevel *log.LoggingLevel `json:"configuredLevel,omitempty"`
}

// LoggersEndpoint implements actuator.Endpoint, actuator.WebEndpoint
//
//goland:noinspection GoNameStartsWithPackageName
type LoggersEndpoint struct {
	actuator.WebEndpointBase
	// pathSuffix is the suffix appended to the default mapping path for each operation.
	pathSuffix map[actuator.Operation]string
}
// newEndpoint builds the loggers endpoint. It has two read operations listing all loggers
// (path suffixes "" and "/"), one read operation for a single logger ("/:name") and one
// write operation changing a logger's level ("/:name").
func newEndpoint(di regDI) *LoggersEndpoint {
	endpoint := &LoggersEndpoint{}
	endpoint.pathSuffix = map[actuator.Operation]string{
		actuator.NewReadOperation(endpoint.ReadAll):    "",
		actuator.NewReadOperation(endpoint.ReadAll):    "/",
		actuator.NewReadOperation(endpoint.ReadByName): "/:name",
		actuator.NewWriteOperation(endpoint.Write):     "/:name",
	}

	operations := make([]actuator.Operation, 0, len(endpoint.pathSuffix))
	for operation := range endpoint.pathSuffix {
		operations = append(operations, operation)
	}

	endpoint.WebEndpointBase = actuator.MakeWebEndpointBase(func(opt *actuator.EndpointOption) {
		opt.Id = ID
		opt.Ops = operations
		opt.Properties = &di.MgtProperties.Endpoints
		opt.EnabledByDefault = EnableByDefault
	})
	return endpoint
}
// Mappings implements WebEndpoint. It builds a single REST mapping for op. Write operations
// use WriteEncodeResponse, so they answer with 204 No Content instead of an encoded body.
func (ep *LoggersEndpoint) Mappings(op actuator.Operation, group string) ([]web.Mapping, error) {
	builder, err := ep.RestMappingBuilder(op, group, ep.MappingPath, ep.MappingName)
	if err != nil {
		return nil, err
	}
	if op.Mode() == actuator.OperationWrite {
		builder.EncodeResponseFunc(ep.WriteEncodeResponse)
	}
	built := builder.Build()
	return []web.Mapping{built}, nil
}
// MappingPath returns the base mapping path of the endpoint followed by the operation's
// suffix from pathSuffix. Operations without a suffix entry get none.
func (ep *LoggersEndpoint) MappingPath(op actuator.Operation, props *actuator.WebEndpointsProperties) string {
	return ep.WebEndpointBase.MappingPath(op, props) + ep.pathSuffix[op]
}
// ReadAll returns a ReadOutput with the supported levels and the effective and configured
// level of every logger reported by log.Levels(""), keyed by logger name.
func (ep *LoggersEndpoint) ReadAll(_ context.Context, _ *struct{}) (interface{}, error) {
	byName := make(map[string]LoggerLevel)
	for _, cfg := range log.Levels("") {
		byName[cfg.Name] = LoggerLevel{
			EffectiveLevel:  cfg.EffectiveLevel,
			ConfiguredLevel: cfg.ConfiguredLevel,
		}
	}
	return ReadOutput{Levels: allLevels, Loggers: byName}, nil
}
// ReadByName returns the levels of the logger named in.Name. A logger matches when its map key
// equals the lower-cased name, or when its Name equals in.Name exactly. If nothing matches,
// it returns an HTTP 404 error.
func (ep *LoggersEndpoint) ReadByName(_ context.Context, in *ReadInput) (interface{}, error) {
	lowered := strings.ToLower(in.Name)
	for key, cfg := range log.Levels(in.Name) {
		if key != lowered && cfg.Name != in.Name {
			continue
		}
		return &LoggerLevel{
			EffectiveLevel:  cfg.EffectiveLevel,
			ConfiguredLevel: cfg.ConfiguredLevel,
		}, nil
	}
	notFound := fmt.Errorf("logger with name %s not found", in.Name)
	return nil, web.NewHttpError(http.StatusNotFound, notFound)
}
// Write calls log.SetLevel with the requested logger name (prefix) and level. A nil level is
// passed through unchanged. The operation result is always nil.
func (ep *LoggersEndpoint) Write(_ context.Context, in *WriteInput) (interface{}, error) {
	prefix, level := in.Prefix, in.ConfiguredLevel
	log.SetLevel(prefix, level)
	return nil, nil
}
// WriteEncodeResponse answers write operations with 204 No Content and no body,
// ignoring the operation result.
func (ep *LoggersEndpoint) WriteEncodeResponse(_ context.Context, rw http.ResponseWriter, _ interface{}) error {
	const status = http.StatusNoContent
	rw.WriteHeader(status)
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package loggers
import (
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"go.uber.org/fx"
)
//var logger = log.New("ACTR.LoggerLevel")
// Module creates the loggers endpoint and registers it with the actuator Registrar
// through fx.Invoke(register).
var Module = &bootstrap.Module{
Name: "actuator-loggers",
Precedence: actuator.MinActuatorPrecedence,
Options: []fx.Option{
fx.Invoke(register),
},
}
// Register adds the loggers Module to the bootstrap module registry.
func Register() {
bootstrap.Register(Module)
}
// regDI holds the dependencies injected into register.
type regDI struct {
fx.In
// Registrar is the actuator registrar the endpoint is registered with.
Registrar *actuator.Registrar
// MgtProperties supplies the endpoints properties used by newEndpoint.
MgtProperties actuator.ManagementProperties
}
// register creates the loggers endpoint and registers it with the actuator Registrar.
// It panics if registration fails (MustRegister).
func register(di regDI) {
	di.Registrar.MustRegister(newEndpoint(di))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuator
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
)
// logger is the package-level logger of the actuator package.
var logger = log.New("Actuator")
// Module provides the Registrar and the bound ManagementProperties, then initializes the
// Registrar (installing web endpoints and web security) through fx.Invoke(initialize).
var Module = &bootstrap.Module{
Name: "actuate",
Precedence: MaxActuatorPrecedence,
Options: []fx.Option{
fx.Provide(NewRegistrar, BindManagementProperties),
fx.Invoke(initialize),
},
}
/**************************
Provider
***************************/
/**************************
Initialize
***************************/
// initialize initializes the Registrar with the injected dependencies and panics if that fails.
func initialize(registrar *Registrar, di initDI) {
	err := registrar.initialize(di)
	if err == nil {
		return
	}
	panic(err)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuator
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
)
const (
// ManagementPropertiesPrefix is the configuration prefix ManagementProperties are bound from.
ManagementPropertiesPrefix = "management"
)
// ManagementProperties holds the actuator configuration bound from the "management" prefix.
type ManagementProperties struct {
// Enabled switches the actuator on or off; when false no web endpoints are installed.
Enabled bool `json:"enabled"`
Endpoints EndpointsProperties `json:"endpoints"`
// BasicEndpoint holds per-endpoint settings keyed by endpoint ID (bound from the singular "endpoint" key).
BasicEndpoint map[string]BasicEndpointProperties `json:"endpoint"`
Security SecurityProperties `json:"security"`
}
// EndpointsProperties holds settings shared by all endpoints.
type EndpointsProperties struct {
// EnabledByDefault enables endpoints that have no explicit BasicEndpoint "enabled" setting.
EnabledByDefault bool `json:"enabled-by-default"`
Web WebEndpointsProperties `json:"web"`
}
// WebEndpointsProperties configures how endpoints are exposed over HTTP.
type WebEndpointsProperties struct {
// BasePath is the path prefix under which actuator endpoints are secured (default "/manage").
BasePath string `json:"base-path"`
// Mappings presumably maps endpoint IDs to custom paths — consumed outside this file, confirm.
Mappings map[string]string `json:"path-mapping"`
Exposure WebExposureProperties `json:"exposure"`
}
// WebExposureProperties selects which endpoints are exposed over HTTP.
type WebExposureProperties struct {
// Endpoint IDs that should be included or '*' for all.
Include utils.StringSet `json:"include"`
// Endpoint IDs that should be excluded or '*' for all.
Exclude utils.StringSet `json:"exclude"`
}
// BasicEndpointProperties holds per-endpoint settings.
type BasicEndpointProperties struct {
// Enabled overrides the endpoint's default when non-nil.
Enabled *bool `json:"enabled"`
}
// SecurityProperties configures access control of actuator endpoints.
type SecurityProperties struct {
// EnabledByDefault applies to endpoints without an explicit security "enabled" setting.
EnabledByDefault bool `json:"enabled-by-default"`
// Permissions are the global permissions required when security is enabled.
Permissions utils.CommaSeparatedSlice `json:"permissions"`
// Endpoints holds per-endpoint security overrides keyed by endpoint ID.
Endpoints map[string]EndpointSecurityProperties `json:"endpoint"`
}
// EndpointSecurityProperties holds security overrides for a single endpoint.
type EndpointSecurityProperties struct {
// Enabled overrides SecurityProperties.EnabledByDefault when non-nil.
Enabled *bool `json:"enabled"`
Permissions utils.CommaSeparatedSlice `json:"permissions"`
}
// NewManagementProperties creates ManagementProperties populated with default values:
// the actuator is enabled, endpoints live under "/manage", every endpoint is exposed over
// web ("*"), security is not enabled by default, and security is explicitly disabled for
// the "alive", "info" and "health" endpoints.
func NewManagementProperties() *ManagementProperties {
	unsecured := map[string]EndpointSecurityProperties{}
	for _, id := range []string{"alive", "info", "health"} {
		unsecured[id] = EndpointSecurityProperties{Enabled: utils.ToPtr(false)}
	}

	return &ManagementProperties{
		Enabled: true,
		Endpoints: EndpointsProperties{
			Web: WebEndpointsProperties{
				BasePath: "/manage",
				Mappings: map[string]string{},
				Exposure: WebExposureProperties{
					Include: utils.NewStringSet("*"),
					Exclude: utils.NewStringSet(),
				},
			},
		},
		BasicEndpoint: map[string]BasicEndpointProperties{},
		Security: SecurityProperties{
			Permissions: []string{},
			Endpoints:   unsecured,
		},
	}
}
// BindManagementProperties creates ManagementProperties pre-populated with the defaults of
// NewManagementProperties, then binds the configuration found under ManagementPropertiesPrefix
// ("management") on top of them. It panics if binding fails.
func BindManagementProperties(ctx *bootstrap.ApplicationContext) ManagementProperties {
	props := NewManagementProperties()
	if err := ctx.Config().Bind(props, ManagementPropertiesPrefix); err != nil {
		panic(errors.Wrap(err, "failed to bind ManagementProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuator
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/mapping"
"github.com/gin-gonic/gin"
"go.uber.org/fx"
"net/http"
)
// constructDI holds the dependencies injected into NewRegistrar.
type constructDI struct {
fx.In
Properties ManagementProperties
}
// initDI holds the dependencies used by Registrar.initialize. The web and security registrars
// are optional; when absent, installing web endpoints or web security is skipped.
type initDI struct {
fx.In
ApplicationContext *bootstrap.ApplicationContext
WebRegistrar *web.Registrar `optional:"true"`
SecurityRegistrar security.Registrar `optional:"true"`
}
// Registrar collects actuator endpoints and customizers before initialization, then installs
// web mappings and web security for them.
type Registrar struct {
// initialized is set by initialize; afterwards registration and re-initialization are rejected.
initialized bool
properties *ManagementProperties
// endpoints holds every registered Endpoint in registration order.
endpoints []Endpoint
// NOTE(review): securityConfigurer is not assigned anywhere in this file — possibly unused.
securityConfigurer security.Configurer
// securityCustomizer and accessCustomizer keep the last registered value; nil means defaults are used.
securityCustomizer SecurityCustomizer
accessCustomizer AccessControlCustomizer
}
// NewRegistrar creates an uninitialized Registrar backed by the injected ManagementProperties.
func NewRegistrar(di constructDI) *Registrar {
	registrar := new(Registrar)
	registrar.properties = &di.Properties
	return registrar
}
// initialize installs all registered web endpoints and then the actuator web security.
// It returns an error if called more than once. The Registrar is marked initialized even
// when installation fails, so later Register calls and re-initialization are rejected.
func (r *Registrar) initialize(di initDI) error {
	if r.initialized {
		return fmt.Errorf("attempting to initialize actuator twice")
	}
	defer func() {
		r.initialized = true
	}()

	// install web endpoints
	webEndpoints, e := r.installWebEndpoints(di.WebRegistrar)
	if e != nil {
		return e
	}
	logger.WithContext(di.ApplicationContext).
		Info(fmt.Sprintf("registered web endpoints %v", webEndpoints.EndpointIDs()))

	// install security
	return r.installWebSecurity(di.SecurityRegistrar, webEndpoints)
}
// MustRegister is like Register but panics on error.
func (r *Registrar) MustRegister(items ...interface{}) {
	err := r.Register(items...)
	if err == nil {
		return
	}
	panic(err)
}
// Register registers each item in order and stops at the first error.
// Supported item types are listed in register.
func (r *Registrar) Register(items ...interface{}) error {
	for i := range items {
		if err := r.register(items[i]); err != nil {
			return err
		}
	}
	return nil
}
// register stores a single item. Accepted types, checked in this order: Endpoint (appended),
// []interface{} (registered recursively), SecurityCustomizer and AccessControlCustomizer
// (replacing any previous one). It fails once the Registrar has been initialized or for any
// other type.
func (r *Registrar) register(item interface{}) (err error) {
	if r.initialized {
		return fmt.Errorf("attempting to register actuator items after actuator has been initialized")
	}
	switch v := item.(type) {
	case Endpoint:
		r.endpoints = append(r.endpoints, v)
		return nil
	case []interface{}:
		return r.Register(v...)
	case SecurityCustomizer:
		r.securityCustomizer = v
		return nil
	case AccessControlCustomizer:
		r.accessCustomizer = v
		return nil
	default:
		return fmt.Errorf("unsupported actuator type [%T]", item)
	}
}
// installWebEndpoints installs every registered endpoint into the web registrar. It returns
// the installed mappings keyed by endpoint ID; endpoints that produced no mapping are left out.
// When there is no web registrar, or the actuator is disabled, it returns (nil, nil).
func (r *Registrar) installWebEndpoints(reg *web.Registrar) (WebEndpoints, error) {
	if reg == nil || !r.properties.Enabled {
		return nil, nil
	}
	installed := WebEndpoints{}
	for _, ep := range r.endpoints {
		mappings, err := r.installWebEndpoint(reg, ep)
		if err != nil {
			return nil, err
		}
		if len(mappings) == 0 {
			continue
		}
		installed[ep.Id()] = mappings
	}
	return installed, nil
}
// installWebEndpoint registers the web mappings of every operation of an enabled, web-exposed
// endpoint. It also registers an OPTIONS handler (responding 200 OK) for each distinct routed
// path. Disabled or non-exposed endpoints yield (nil, nil).
func (r *Registrar) installWebEndpoint(reg *web.Registrar, endpoint Endpoint) ([]web.Mapping, error) {
	if reg == nil || !r.isEndpointEnabled(endpoint) || !r.shouldExposeToWeb(endpoint) {
		return nil, nil
	}
	// shouldExposeToWeb guarantees endpoint is a WebEndpoint
	webEndpoint := endpoint.(WebEndpoint)
	operations := endpoint.Operations()
	installed := make([]web.Mapping, 0, len(operations))
	routedPaths := utils.NewStringSet()
	for _, operation := range operations {
		opMappings, err := webEndpoint.Mappings(operation, "")
		if err != nil {
			return nil, err
		}
		if err := reg.Register(opMappings); err != nil {
			return nil, err
		}
		installed = append(installed, opMappings...)
		for _, m := range opMappings {
			routed, ok := m.(web.RoutedMapping)
			if !ok {
				continue
			}
			routedPaths.Add(routed.Group() + routed.Path())
		}
	}

	for routedPath := range routedPaths {
		optionsMapping := mapping.Options(routedPath).HandlerFunc(optionsHttpHandlerFunc()).Build()
		if err := reg.Register(optionsMapping); err != nil {
			return nil, err
		}
	}
	return installed, nil
}
// installWebSecurity registers the actuator security configurer for the installed endpoints.
// Without a security registrar it does nothing. It always returns nil.
func (r *Registrar) installWebSecurity(reg security.Registrar, endpoints WebEndpoints) error {
	if reg != nil {
		reg.Register(newActuatorSecurityConfigurer(r.properties, endpoints, r.securityCustomizer, r.accessCustomizer))
	}
	return nil
}
/*******************************
internal
********************************/
// isEndpointEnabled reports whether an endpoint is enabled. The actuator must be enabled.
// An explicit management.endpoint.<id>.enabled setting wins. Otherwise the endpoint is
// enabled if either its own default or the global endpoints default enables it.
func (r *Registrar) isEndpointEnabled(endpoint Endpoint) bool {
	if !r.properties.Enabled {
		return false
	}
	basic, found := r.properties.BasicEndpoint[endpoint.Id()]
	if found && basic.Enabled != nil {
		return *basic.Enabled
	}
	return endpoint.EnabledByDefault() || r.properties.Endpoints.EnabledByDefault
}
// shouldExposeToWeb reports whether an endpoint is exposed over HTTP according to
// management.endpoints.web.exposure.include / exclude:
//   - endpoints that are not WebEndpoint are never exposed;
//   - an explicit exclusion of the endpoint ID always wins;
//   - with "*" in exclude, only explicitly included endpoints are exposed;
//   - otherwise the endpoint is exposed when included explicitly or via "*".
func (r *Registrar) shouldExposeToWeb(endpoint Endpoint) bool {
	if _, ok := endpoint.(WebEndpoint); !ok {
		return false
	}
	exposure := r.properties.Endpoints.Web.Exposure
	includeAll := exposure.Include.Has("*")
	include := exposure.Include.Has(endpoint.Id())
	excludeAll := exposure.Exclude.Has("*")
	// must read the Exclude set: reading Include here made every explicitly included endpoint excluded
	exclude := exposure.Exclude.Has(endpoint.Id())
	switch {
	case !excludeAll && !exclude && (includeAll || include):
		// no exclusion & include is set
		return true
	case !exclude && include:
		// no explicit exclusion & explicit inclusion
		return true
	default:
		// explicit exclusion or implicit exclusion without explicit inclusion
		return false
	}
}
// optionsHttpHandlerFunc returns a handler that answers with 200 OK and no body.
// It is used for the OPTIONS routes of actuator endpoints.
func optionsHttpHandlerFunc() gin.HandlerFunc {
	handler := func(gc *gin.Context) {
		gc.Status(http.StatusOK)
	}
	return handler
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuator
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/tokenauth"
matcherutils "github.com/cisco-open/go-lanai/pkg/utils/matcher"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"net/http"
"regexp"
)
/*******************************
Interfaces
********************************/
// SecurityCustomizer customizes the actuator's web security. The Registrar keeps a single
// SecurityCustomizer: registering another one replaces the previous one.
// SecurityCustomizer is typically responsible for setting up the authentication scheme;
// it should not configure access control, which is configured per-endpoint via properties or AccessControlCustomizer
type SecurityCustomizer interface {
Customize(ws security.WebSecurity)
}
// SecurityCustomizerFunc converts a function to interface SecurityCustomizer
type SecurityCustomizerFunc func(ws security.WebSecurity)
// Customize implements SecurityCustomizer by calling f(ws).
func (f SecurityCustomizerFunc) Customize(ws security.WebSecurity) {
f(ws)
}
// AccessControlCustomizer is similar to SecurityCustomizer, but is used to customize the access control of each endpoint.
// Implementations of AccessControlCustomizer can be registered via Registrar, and NewAccessControlByPermissions is
// used if no other customizer is registered.
// Also See NewSimpleAccessControl, NewAccessControlByPermissions, NewAccessControlByScopes
type AccessControlCustomizer interface {
Customize(ac *access.AccessControlFeature, epId string, paths []string)
}
// AccessControlCustomizeFunc converts a function to interface AccessControlCustomizer
type AccessControlCustomizeFunc func(ac *access.AccessControlFeature, epId string, paths []string)
// Customize implements AccessControlCustomizer by calling f(ac, epId, paths).
func (f AccessControlCustomizeFunc) Customize(ac *access.AccessControlFeature, epId string, paths []string) {
f(ac, epId, paths)
}
/*******************************
Security Configurer
********************************/
// actuatorSecurityConfigurer implements security.Configurer
// It secures the installed actuator web endpoints using the configured customizers.
type actuatorSecurityConfigurer struct {
properties *ManagementProperties
// endpoints are the installed web mappings keyed by endpoint ID.
endpoints WebEndpoints
customizer SecurityCustomizer
acCustomizer AccessControlCustomizer
}
// newActuatorSecurityConfigurer creates the security configurer. A nil customizer defaults to
// NewTokenAuthSecurity(). A nil acCustomizer defaults to
// NewAccessControlByPermissions(properties.Security).
func newActuatorSecurityConfigurer(properties *ManagementProperties, endpoints WebEndpoints, customizer SecurityCustomizer, acCustomizer AccessControlCustomizer) *actuatorSecurityConfigurer {
	configurer := &actuatorSecurityConfigurer{
		properties:   properties,
		endpoints:    endpoints,
		customizer:   customizer,
		acCustomizer: acCustomizer,
	}
	if configurer.customizer == nil {
		configurer.customizer = NewTokenAuthSecurity()
	}
	if configurer.acCustomizer == nil {
		configurer.acCustomizer = NewAccessControlByPermissions(properties.Security)
	}
	return configurer
}
// Configure implements security.Configurer. It does the following, in order:
//  1. applies the SecurityCustomizer (typically the authentication scheme);
//  2. routes this configuration to requests under the management base path, except OPTIONS
//     requests, and adds error handling;
//  3. lets the AccessControlCustomizer add access rules for each installed web endpoint;
//  4. adds a fallback rule for any other request: Authenticated when
//     Security.EnabledByDefault is true, PermitAll otherwise.
func (c *actuatorSecurityConfigurer) Configure(ws security.WebSecurity) {
	if c.customizer != nil {
		c.customizer.Customize(ws)
	}

	// OPTIONS requests are excluded; the Registrar installs plain OPTIONS handlers for endpoint paths
	path := fmt.Sprintf("%s/**", c.properties.Endpoints.Web.BasePath)
	ws.Route(matcher.RouteWithPattern(path).And(matcherutils.Not(matcher.RouteWithMethods(http.MethodOptions)))).
		With(errorhandling.New())

	// configure access control based on customizer and installed web endpoints
	ac := access.Configure(ws)
	for epId := range c.endpoints {
		c.acCustomizer.Customize(ac, epId, c.endpoints.Paths(epId))
	}

	// fallback configuration
	if c.properties.Security.EnabledByDefault {
		ac.Request(matcher.AnyRequest()).Authenticated()
	} else {
		ac.Request(matcher.AnyRequest()).PermitAll()
	}
}
/*******************************
Common Implementation
********************************/
// NewTokenAuthSecurity returns a SecurityCustomizer that makes actuator security use OAuth2
// token auth (tokenauth feature). It is the default SecurityCustomizer when none is registered.
func NewTokenAuthSecurity() SecurityCustomizer {
	withTokenAuth := func(ws security.WebSecurity) {
		ws.With(tokenauth.New())
	}
	return SecurityCustomizerFunc(withTokenAuth)
}
// NewSimpleAccessControl is a convenient AccessControlCustomizer constructor. It applies one
// access rule to ALL paths of each endpoint. Path variables are turned into wildcards, the
// paths are OR-ed together, and the rule is the access.ControlFunc that acCreator returns for
// the endpoint ID. Endpoints without paths are skipped.
func NewSimpleAccessControl(acCreator func(epId string) access.ControlFunc) AccessControlCustomizer {
	customize := func(ac *access.AccessControlFeature, epId string, paths []string) {
		if len(paths) == 0 {
			return
		}
		requestMatcher := pathToRequestPattern(paths[0])
		for i := 1; i < len(paths); i++ {
			requestMatcher = requestMatcher.Or(pathToRequestPattern(paths[i]))
		}
		control := acCreator(epId)
		ac.Request(requestMatcher).AllowIf(control)
	}
	return AccessControlCustomizeFunc(customize)
}
// NewAccessControlByPermissions returns an AccessControlCustomizer that uses SecurityProperties
// and the given default permissions to set up the access control of each endpoint:
//  1. if security of an endpoint is not enabled, access.PermitAll is used;
//  2. otherwise the configured permissions are used, falling back to defaultPerms when none
//     are configured;
//  3. if that still yields no permissions, access.Authenticated is used.
//
// This is the default AccessControlCustomizer if no other AccessControlCustomizer is registered
func NewAccessControlByPermissions(properties SecurityProperties, defaultPerms ...string) AccessControlCustomizer {
	return NewSimpleAccessControl(func(epId string) access.ControlFunc {
		enabled, permissions := collectSecurityFacts(epId, &properties)
		if !enabled {
			return access.PermitAll
		}
		if len(permissions) == 0 {
			permissions = defaultPerms
		}
		if len(permissions) == 0 {
			return access.Authenticated
		}
		return access.HasPermissions(permissions...)
	})
}
// NewAccessControlByScopes returns an AccessControlCustomizer that uses SecurityProperties and
// the given default approved scopes to set up the access control of each endpoint.
// "usePermissions" indicates whether the permissions configured in SecurityProperties are used
// as the scopes to check.
//
//  1. If security of an endpoint is not enabled, access.PermitAll is used.
//  2. If usePermissions is true but no permissions are configured in SecurityProperties,
//     defaultScopes is used to resolve scopes.
//  3. If no scopes are configured (regardless of usePermissions), access.Authenticated is used.
//
// Note: This customizer is particularly useful for client_credentials grant type
func NewAccessControlByScopes(properties SecurityProperties, usePermissions bool, defaultScopes ...string) AccessControlCustomizer {
	return NewSimpleAccessControl(func(epId string) access.ControlFunc {
		enabled, permissions := collectSecurityFacts(epId, &properties)
		if !enabled {
			return access.PermitAll
		}
		scopes := defaultScopes
		if usePermissions && len(permissions) != 0 {
			scopes = permissions
		}
		if len(scopes) == 0 {
			return access.Authenticated
		}
		return tokenauth.ScopesApproved(scopes...)
	})
}
// pathVarRegex matches path variables written as ":name" (letters, digits, '-' and '_').
var pathVarRegex = regexp.MustCompile(`:[a-zA-Z0-9\-_]+`)

// pathToRequestPattern converts path variables into wildcards and returns a request matcher
// for the result, e.g. "/path/to/:any/endpoint" becomes "/path/to/*/endpoint".
func pathToRequestPattern(path string) web.RequestMatcher {
	return matcher.RequestWithPattern(pathVarRegex.ReplaceAllString(path, "*"))
}
// collectSecurityFacts resolves whether security is enabled for the endpoint epId and which
// permissions it requires. Later sources override earlier ones:
//   - enabled: SecurityProperties.EnabledByDefault, then the endpoint's explicit "enabled";
//   - permissions: defaults, then the global Permissions, then the endpoint's own Permissions.
//
// Empty permission lists never override, so an endpoint entry that only sets "enabled" keeps
// inheriting the global permissions.
func collectSecurityFacts(epId string, properties *SecurityProperties, defaults ...string) (enabled bool, permissions []string) {
	permissions = defaults
	enabled = properties.EnabledByDefault
	if len(properties.Permissions) != 0 {
		permissions = properties.Permissions
	}
	if props, ok := properties.Endpoints[epId]; ok {
		// previously an endpoint entry without permissions wiped out the global permissions
		if len(props.Permissions) != 0 {
			permissions = props.Permissions
		}
		if props.Enabled != nil {
			enabled = *props.Enabled
		}
	}
	return enabled, permissions
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package args
import (
"os"
"strings"
)
// ExtraFlags parses the original CLI flags (os.Args, up to a standalone "--"). It returns the
// flags not pre-defined by the application, i.e. flags other than e.g. --help or --profile.
// Both "--flag=value" and "--flag value" are accepted:
//   - a flag for which skip(name) returns true is ignored, together with its value;
//   - "--flag" with no following value (last argument, or followed by another "--..."
//     argument) is ignored;
//   - arguments not starting with "--" that are not consumed as a value are ignored.
func ExtraFlags(skip func(name string) bool) (extras map[string]string) {
	extras = make(map[string]string)
	args := os.Args[1:]
	for n := 0; n < len(args); n++ {
		arg := args[n]
		if arg == "--" {
			break
		}
		if !strings.HasPrefix(arg, "--") {
			continue
		}
		name := arg[2:]

		// "--flag=value": the value is self-contained, so the next argument is never consumed,
		// whether or not the flag is skipped.
		if split := strings.SplitN(name, "=", 2); len(split) == 2 {
			if !skip(split[0]) {
				extras[split[0]] = split[1]
			}
			continue
		}

		// "--flag value": requires a following argument that is not itself a flag
		if n == len(args)-1 || strings.HasPrefix(args[n+1], "--") {
			continue
		}
		if !skip(name) {
			extras[name] = args[n+1]
		}
		// the next argument is this flag's value, whether recorded or skipped
		n++
	}
	return extras
}
// ExtraKVArgs parses the CLI arguments that follow a standalone "--". An argument "key=value"
// is split at its first '='; an argument without '=' maps to an empty value. The result is
// never nil.
func ExtraKVArgs(args []string) (extras map[string]string) {
	extras = make(map[string]string, len(args))
	for _, arg := range args {
		if idx := strings.Index(arg, "="); idx >= 0 {
			extras[arg[:idx]] = arg[idx+1:]
		} else {
			extras[arg] = ""
		}
	}
	return extras
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package cliprovider
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/appconfig"
"github.com/cisco-open/go-lanai/pkg/appconfig/args"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/spf13/pflag"
"sync"
)
var (
// declaredFlagsMapping maps declared CLI flag names to the property keys that also
// receive the flag's value in ConfigProvider.Load.
declaredFlagsMapping = map[string]string{
bootstrap.CliFlagActiveProfile: appconfig.PropertyKeyActiveProfiles,
bootstrap.CliFlagAdditionalProfile: appconfig.PropertyKeyAdditionalProfiles,
bootstrap.CliFlagConfigSearchPath: appconfig.PropertyKeyConfigFileSearchPath,
}
)
// ConfigProvider provides properties from the command line.
// Note: for a command like ./app --active-profiles develop --dynamic.flag.example=1 -- kvarg-example=b
// --active-profiles develop is the declared flag, because we declared active-profiles as a known command line flag.
// e.g. --help will display a description about this flag's usage
// --dynamic.flag.example=1 is the dynamic flag, we did not declare this flag at development time.
// kvarg-example=b is the command's argument. Since it's in kv format, we will extract them into kvArgs, and process them as properties
type ConfigProvider struct {
appconfig.ProviderMeta
// prefix is prepended to declared flag names to form their property keys.
prefix string
declaredFlags map[string]interface{} // flags pre-declared by our command (e.g. --help). Cobra will parse these.
args []string // cobra arguments (args after standalone "--" )
dynamicFlags map[string]string // flags not declared by us, so we need to parse these ourselves.
kvArgs map[string]string // key=value pairs from cobra arguments (args after standalone "--" )
// once guards the one-time parsing of dynamicFlags and kvArgs in Load.
once sync.Once
}
// Name returns the provider name, "command-line".
func (configProvider *ConfigProvider) Name() string {
return "command-line"
}
// Load builds the provider's settings from, in increasing order of precedence: dynamic flags,
// declared flags (stored under prefix+name, and also under the mapped property key listed in
// declaredFlagsMapping), and key=value arguments. The raw CLI input is parsed only once. The
// provider is marked loaded only when this call succeeds.
func (configProvider *ConfigProvider) Load(_ context.Context) (loadError error) {
	defer func() {
		configProvider.Loaded = loadError == nil
	}()

	configProvider.once.Do(func() {
		isDeclared := func(name string) bool {
			_, declared := configProvider.declaredFlags[name]
			return declared
		}
		configProvider.dynamicFlags = args.ExtraFlags(isDeclared)
		configProvider.kvArgs = args.ExtraKVArgs(configProvider.args)
	})

	flat := make(map[string]interface{})
	for key, raw := range configProvider.dynamicFlags {
		flat[key] = utils.ParseString(raw)
	}
	for name, raw := range configProvider.declaredFlags {
		converted := configProvider.convertDeclaredFlag(raw)
		flat[configProvider.prefix+name] = converted
		if propertyKey, mapped := declaredFlagsMapping[name]; mapped {
			flat[propertyKey] = converted
		}
	}
	for key, raw := range configProvider.kvArgs {
		flat[key] = utils.ParseString(raw)
	}

	nested, err := appconfig.UnFlatten(flat)
	if err != nil {
		return err
	}
	configProvider.Settings = nested
	return nil
}
// convertDeclaredFlag converts a declared flag value into a property value. Slice flags
// become a slice with each element parsed by utils.ParseString. Other pflag values and
// fmt.Stringer values are parsed from their String() form. Anything else is formatted
// with "%v".
func (configProvider *ConfigProvider) convertDeclaredFlag(value interface{}) interface{} {
	if sliceValue, ok := value.(pflag.SliceValue); ok {
		items := sliceValue.GetSlice()
		converted := make([]interface{}, 0, len(items))
		for _, item := range items {
			converted = append(converted, utils.ParseString(item))
		}
		return converted
	}
	// pflag.Value includes String(), so plain pflag values are handled here as well
	if stringer, ok := value.(fmt.Stringer); ok {
		return utils.ParseString(stringer.String())
	}
	return fmt.Sprintf("%v", value)
}
// NewCobraProvider creates a command-line ConfigProvider from the executing cobra command.
// Only flags explicitly set on the command line (inherited and local) are kept as declared
// flags. The command arguments are kept for key=value parsing.
func NewCobraProvider(precedence int, execCtx *bootstrap.CliExecContext, prefix string) *ConfigProvider {
	changedFlags := make(map[string]interface{})
	collectChanged := func(f *pflag.Flag) {
		if !f.Changed {
			return
		}
		changedFlags[f.Name] = f.Value
	}
	execCtx.Cmd.InheritedFlags().VisitAll(collectChanged)
	execCtx.Cmd.LocalFlags().VisitAll(collectChanged)

	provider := &ConfigProvider{
		ProviderMeta:  appconfig.ProviderMeta{Precedence: precedence},
		prefix:        prefix,
		declaredFlags: changedFlags,
		args:          execCtx.Args,
	}
	return provider
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package cliprovider
import (
"context"
"github.com/cisco-open/go-lanai/pkg/appconfig"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
)
const (
// defaultConfigSearchPath is the value set for appconfig.PropertyKeyConfigFileSearchPath.
defaultConfigSearchPath = "configs"
)
// StaticConfigProvider provides fixed settings derived from the CLI command and build:
// application name, config file search path and build info.
type StaticConfigProvider struct {
appconfig.ProviderMeta
// appName is the name of the root CLI command.
appName string
}
// NewStaticConfigProvider creates a StaticConfigProvider with the given precedence. It uses
// the root command's name as the application name.
func NewStaticConfigProvider(order int, execCtx *bootstrap.CliExecContext) *StaticConfigProvider {
	provider := &StaticConfigProvider{appName: execCtx.Cmd.Root().Name()}
	provider.ProviderMeta = appconfig.ProviderMeta{Precedence: order}
	return provider
}
// Name returns the provider name, "default".
func (p *StaticConfigProvider) Name() string {
return "default"
}
// Load sets the application name, the default config search path and the build info. It
// un-flattens them into the provider's settings. Settings are replaced only on success, and
// the provider is marked loaded only when no error occurred.
func (p *StaticConfigProvider) Load(_ context.Context) (err error) {
	defer func() {
		p.Loaded = err == nil
	}()

	flat := map[string]interface{}{}
	flat[appconfig.PropertyKeyApplicationName] = p.appName
	flat[appconfig.PropertyKeyConfigFileSearchPath] = []string{defaultConfigSearchPath}
	flat[appconfig.PropertyKeyBuildInfo] = bootstrap.BuildInfoMap

	nested, e := appconfig.UnFlatten(flat)
	if e != nil {
		return e
	}
	p.Settings = nested
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package appconfig
import (
"context"
"dario.cat/mergo"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/pkg/errors"
"strconv"
"strings"
)
var (
// logger is the package-level logger of the appconfig package.
logger = log.New("Config")
//ErrNotLoaded = errors.New("Configuration not loaded")
// errBindWithConfigBeforeLoaded presumably signals binding before the config is loaded — used outside this chunk, confirm.
errBindWithConfigBeforeLoaded = errors.New("attempt to bind with config before it's loaded")
)
// properties implements bootstrap.ApplicationConfig
// It holds the merged, un-flattened (nested map) property tree.
type properties map[string]interface{}
// makeInitialProperties returns an empty, non-nil properties map.
func makeInitialProperties() properties {
return map[string]interface{}{}
}
// Value returns the value stored under key, looked up by the package helper value
// (defined elsewhere; presumably resolves dotted keys — TODO confirm).
func (p properties) Value(key string) interface{} {
return value(p, key)
}
// Bind binds the properties under prefix to target. prefix is a dot-separated path into the
// nested property tree; "" or "." binds the whole tree. Binding is a JSON round-trip, so
// target fields are matched by their json tags. Keys are normalized when properties are
// loaded (ProcessKeyFormat with NormalizeKey), so tags should use the normalized form.
// If the prefix path cannot be walked (a non-map value on the way), nothing is bound and
// nil is returned. A missing final key binds JSON null, which leaves target unchanged.
func (p properties) Bind(target interface{}, prefix string) error {
	var source interface{} = map[string]interface{}(p)
	if prefix != "" && prefix != "." {
		for _, key := range strings.Split(prefix, ".") {
			node, isMap := source.(map[string]interface{})
			if !isMap {
				// prefix doesn't exist, we just don't bind it
				return nil
			}
			source = node[key]
		}
	}
	data, err := json.Marshal(source)
	if err != nil {
		return err
	}
	return json.Unmarshal(data, target)
}
// config is the application configuration: the merged properties plus the provider groups
// and providers they were loaded from.
type config struct {
properties
groups []ProviderGroup
providers []Provider //such as yaml auth, commandline etc.
// profiles is cleared on a forced Load; presumably the resolved profiles — confirm.
profiles utils.StringSet
// isLoaded is set by Load: true on success, false on error.
isLoaded bool
}
// Load collects properties from all provider groups, merges them and resolves placeholders.
//
// Groups are processed lower priority first. Within each group, providers are sorted from lower to higher
// precedence, so later merges override earlier ones. Groups receive the properties merged so far, so the
// set of providers may change between iterations (e.g. profile-based groups). Groups are therefore queried
// repeatedly until an iteration loads no new provider. When force is true, all groups are reset first.
//
// Load fails if a provider fails to load, if keys cannot be formatted or merged, or if a placeholder cannot
// be resolved (e.g. due to circular dependency). isLoaded reflects whether this call succeeded.
func (c *config) Load(ctx context.Context, force bool) (err error) {
defer func() {
if err != nil {
c.isLoaded = false
} else {
c.isLoaded = true
}
}()
// sort groups based on order, we process lower priority first
order.SortStable(c.groups, order.OrderedFirstCompareReverse)
// reset all groups if force == true
if force {
c.isLoaded = false
c.profiles = nil
for _, g := range c.groups {
g.Reset()
}
}
// repeatedly process provider groups until list of provider become stable and all loaded
var providers []Provider
final := makeInitialProperties()
// Note about hasNew check: when transiting from bootstrap config to application config,
// and all initial providers are from bootstrap config, all providers are loaded initially.
// However, we still need to re-collect/merge all properties.
// In this case, we need to set hasNew to true in the first iteration.
for hasNew, isFirstIter := true, true; hasNew; {
providers = make([]Provider, 0)
// force a merge on the first iteration even if nothing new gets loaded (see note above)
hasNew = isFirstIter
isFirstIter = false
for _, g := range c.groups {
// sort providers based on precedence, lower to higher
// groups receive the properties merged in the previous iteration
group := g.Providers(ctx, final)
order.SortStable(group, order.OrderedFirstCompareReverse)
providers = append(providers, group...)
// Load config from each source if it's not loaded yet
for _, provider := range group {
if !provider.IsLoaded() {
if e := provider.Load(ctx); e != nil {
err = errors.Wrap(e, "Failed to load properties")
return
}
hasNew = true
}
}
}
// If no new provider was loaded, quit without re-merging all sources
if !hasNew {
break
}
// merge properties and deal special merging rules on some properties
// Note all properties returned by Provider should be un-flattened
merged := makeInitialProperties()
additionalProfiles := make([]string, 0)
for _, p := range providers {
if p.GetSettings() == nil {
continue
}
formatted, e := ProcessKeyFormat(p.GetSettings(), NormalizeKey)
if e != nil {
err = errors.Wrap(e, "Failed to format keys before merge")
return
}
if e := mergo.Merge(&merged, properties(formatted), mergo.WithOverride); e != nil {
err = errors.Wrap(e, "Failed to merge properties from property sources")
return
}
// special treatments:
// - PropertyKeyAdditionalProfiles need to be appended instead of overridden
if additionalProfiles, e = mergeAdditionalProfiles(additionalProfiles, formatted); e != nil {
// assigns the named result err (unwrapped)
return e
}
}
// replace the overridden value with the accumulated list from all providers
if e := setValue(merged, PropertyKeyAdditionalProfiles, additionalProfiles, true); e != nil {
return e
}
final = merged
}
// resolve placeholder
if err = resolve(ctx, final); err != nil {
return
}
c.properties = final
// resolve profiles
c.profiles = utils.NewStringSet()
for _, v := range resolveProfiles(final) {
if v != "" {
c.profiles.Add(v)
}
}
// providers are stored in highest precedence first
// (the collected list is lower-to-higher, so it is reversed here)
l := len(providers)
c.providers = make([]Provider, l)
for i, v := range providers {
c.providers[l-i-1] = v
}
return
}
// Value returns the property at the given flat key, or nil while the config is not successfully loaded.
func (c *config) Value(key string) interface{} {
	if c.isLoaded {
		return c.properties.Value(key)
	}
	return nil
}

// Bind binds the properties under prefix into target.
// It returns errBindWithConfigBeforeLoaded while the config is not successfully loaded.
func (c *config) Bind(target interface{}, prefix string) error {
	switch {
	case !c.isLoaded:
		return errBindWithConfigBeforeLoaded
	default:
		return c.properties.Bind(target, prefix)
	}
}
// Each goes through all properties and applies the given function to each flat key and its value.
// It stops at the first error and returns it.
// NOTE(review): unlike Value and Bind, Each does not check whether the config has been loaded.
func (c *config) Each(apply func(string, interface{}) error) error {
return VisitEach(c.properties, apply)
}
// Providers returns the effective providers of the last successful Load, highest precedence first.
func (c *config) Providers() []Provider {
return c.providers
}
// Profiles returns the profiles resolved during Load.
func (c *config) Profiles() []string {
return c.profiles.Values()
}
// HasProfile reports whether the given profile was resolved during Load.
func (c *config) HasProfile(profile string) bool {
return c.profiles.Has(profile)
}
// ProviderGroups returns the provider groups this config was created with.
func (c *config) ProviderGroups() []ProviderGroup {
return c.groups
}
/*********************
Helpers
*********************/
// value looks up the leaf value at flatKey (e.g. "a.b[0].c") in the nested tree.
// It returns nil when the key is absent or the path cannot be traversed.
func value(nested map[string]interface{}, flatKey string) (ret interface{}) {
	var found interface{}
	collect := func(_ string, v interface{}, isLeaf bool, _ int) interface{} {
		if isLeaf {
			found = v
		}
		// read-only traversal: never replace a node
		return nil
	}
	if err := visit(nested, flatKey, collect); err != nil {
		return nil
	}
	return found
}
// setValue sets val at flatKey in the nested tree.
//
// When createIntermediateNodes is true, missing (nil) intermediate nodes along the path are created.
// An indexed segment gets a slice of empty maps, just long enough for that index; any other segment
// gets an empty map. Existing intermediate nodes are never replaced or grown, so an "out of bound"
// error is still possible when an existing slice is shorter than the requested index.
// Setting a nil val has no effect, because visit ignores nil replacements.
func setValue(nested map[string]interface{}, flatKey string, val interface{}, createIntermediateNodes bool) error {
	return visit(nested, flatKey, func(_ string, current interface{}, isLeaf bool, expectedSliceLen int) interface{} {
		if isLeaf {
			return val
		}
		if current != nil || !createIntermediateNodes {
			// keep the existing intermediate node, or don't create any
			return nil
		}
		if expectedSliceLen <= 0 {
			return map[string]interface{}{}
		}
		created := make([]interface{}, expectedSliceLen)
		for i := range created {
			created[i] = map[string]interface{}{}
		}
		return created
	})
}
// visitFunc is invoked by visit for every node along a key path.
// keyPath is the node's partial flat key and v its current value. isLeaf reports whether the node is the
// final node of the path. expectedSliceLen is index+1 when the node is expected to hold a slice, and
// 0 or less otherwise. A non-nil return value replaces the node in its parent.
type visitFunc func(keyPath string, v interface{}, isLeaf bool, expectedSliceLen int) interface{}

// visit traverses the given tree (map) along the path represented as flatKey (e.g. flat.key[0].path).
// It calls overrideFunc with each node's partial key path and its value; if the returned value is non-nil,
// it replaces the node.
//
// An error is returned when the path cannot be followed. That covers a non-map value where a map is
// expected, a non-slice value where an indexed element is expected, an out-of-bound index, and an index
// that is not a non-negative integer.
func visit(nested map[string]interface{}, flatKey string, overrideFunc visitFunc) error {
	nestedKeys := UnFlattenKey(NormalizeKey(flatKey))
	partialKey := ""
	var tmp interface{} = nested
	for i, nestedKey := range nestedKeys {
		isLast := i == len(nestedKeys)-1

		// split "key[index]" into key and index; only a "]" after the "[" closes the index
		index := -1
		if indexStart := strings.Index(nestedKey, "["); indexStart > -1 {
			if rel := strings.Index(nestedKey[indexStart:], "]"); rel > -1 {
				indexStr := nestedKey[indexStart+1 : indexStart+rel]
				n, e := strconv.Atoi(indexStr)
				if e != nil || n < 0 {
					return fmt.Errorf("invalid index %q in key %q at key path %s", indexStr, nestedKey, partialKey)
				}
				index = n
				nestedKey = nestedKey[:indexStart]
			}
		}

		m, ok := tmp.(map[string]interface{})
		if !ok {
			return fmt.Errorf("incorrect type at key path %s. expected map[string]interface{}, but got %T", partialKey, tmp)
		}
		// get value and attempt to override
		tmp = m[nestedKey]
		partialKey = joinKeyPaths(partialKey, nestedKey)
		if v := overrideFunc(partialKey, tmp, isLast && index < 0, index+1); v != nil {
			m[nestedKey] = v
			tmp = v
		}
		if index < 0 {
			continue
		}

		// slice element
		s, ok := tmp.([]interface{})
		if !ok {
			return fmt.Errorf("incorrect type at key path %s. expected []interface{}, but got %T", partialKey, tmp)
		}
		if index >= len(s) {
			return fmt.Errorf("index %d out of bound (%d) at key path %s", index, len(s), partialKey)
		}
		// attempt to override; the element's key path is "<key>[index]"
		tmp = s[index]
		partialKey = joinKeyPaths(partialKey, index)
		if v := overrideFunc(partialKey, tmp, isLast, -1); v != nil {
			s[index] = v
			tmp = v
		}
	}
	return nil
}
func joinKeyPaths(left string, right interface{}) string {
switch r := right.(type) {
case string:
switch {
case left == "":
return r
case right == "":
return left
default:
return left + "." + r
}
case int:
return left + "[" + strconv.Itoa(r) + "]"
default:
return ""
}
}
// mergeAdditionalProfiles appends the profiles declared under PropertyKeyAdditionalProfiles in src to profiles.
// The value may be absent, a single string, a []string, or a []interface{} of strings.
// Any other type, or a non-string element, is an error.
func mergeAdditionalProfiles(profiles []string, src map[string]interface{}) ([]string, error) {
	raw := value(src, PropertyKeyAdditionalProfiles)
	switch v := raw.(type) {
	case nil:
		return profiles, nil
	case string:
		profiles = append(profiles, v)
	case []string:
		profiles = append(profiles, v...)
	case []interface{}:
		for i, p := range v {
			s, ok := p.(string)
			if !ok {
				// report the offending element's type, not the slice's
				return nil, fmt.Errorf("invalid type %T at key path %s[%d]", p, PropertyKeyAdditionalProfiles, i)
			}
			profiles = append(profiles, s)
		}
	default:
		return nil, fmt.Errorf("invalid type %T at key path %s", raw, PropertyKeyAdditionalProfiles)
	}
	return profiles, nil
}
/*********************
Placeholder
*********************/
// resolve replaces placeholders in place, for every key/value pair visited by VisitEach over nested.
// It returns the first resolution error (e.g. a circular reference).
func resolve(ctx context.Context, nested map[string]interface{}) error {
	return VisitEach(nested, func(key string, v interface{}) error {
		_, err := resolveValue(ctx, key, v, nested, nil)
		return err
	})
}
// resolveValue recursively resolves the value of key by replacing placeholders with actual values.
// The resolved value is written back into source.
// Note: here the key is the flattened key.
//
// Non-string values, and strings without placeholders, are returned as-is. A placeholder that cannot be
// resolved falls back to its default value when one is present. A resolution error without a default is
// returned. visited holds the chain of keys currently being resolved and is used to detect circular
// references.
func resolveValue(ctx context.Context, key string, val interface{}, source map[string]interface{}, visited []string) (resolvedVal interface{}, err error) {
	// if value is not string, no need to resolve it further
	strVal, ok := val.(string)
	if !ok {
		return val, nil
	}
	placeholders, isEmbedded, e := parsePlaceHolder(strVal)
	if e != nil {
		return nil, e
	}
	if len(placeholders) == 0 {
		return val, nil
	}

	// check for circular reference
	visited = append(visited, key)
	// constant format string: keys may contain '%'
	logger.WithContext(ctx).Debugf("resolving key: %s", key)
	for _, ph := range placeholders {
		for i, k := range visited {
			if k == ph.key {
				circular := strings.Join(visited[i:], "->") + "->" + ph.key
				return nil, fmt.Errorf("placeholder ${%s} can't be resolved due to circular reference: %s", ph.key, circular)
			}
		}
	}

	resolvedKV := make(map[string]interface{})
	resolvedPlaceholder := make(map[string]placeholder)
	for _, ph := range placeholders {
		v := value(source, ph.key)
		switch resolved, e := resolveValue(ctx, ph.key, v, source, visited); {
		case e == nil && resolved == nil:
			// cannot resolve value, fall back to default if any
			if ph.defaultVal != nil {
				resolvedKV[ph.key] = ph.defaultVal
				resolvedPlaceholder[ph.key] = ph
			}
		case e == nil:
			resolvedKV[ph.key] = resolved
			resolvedPlaceholder[ph.key] = ph
		case e != nil && ph.defaultVal != nil:
			// resolution failed but a default is available: warn and use it
			logger.WithContext(ctx).Warnf("%v", e)
			resolvedKV[ph.key] = ph.defaultVal
			resolvedPlaceholder[ph.key] = ph
		default:
			return nil, e
		}
	}

	// embedded means the placeholder is embedded in the value string, either with the format of
	// "somestring${a}" or "${a}${b}"; resolved values are rendered as strings and concatenated.
	var resolvedValue interface{}
	if isEmbedded {
		str := strVal
		for phKey, resolved := range resolvedKV {
			str = strings.ReplaceAll(str, resolvedPlaceholder[phKey].String(), fmt.Sprint(resolved))
		}
		resolvedValue = str
	} else {
		// if not embedded, the entire value must have just been a single placeholder.
		resolvedValue = resolvedKV[placeholders[0].key]
	}

	if e := setValue(source, key, resolvedValue, false); e != nil {
		return nil, e
	}
	return resolvedValue, nil
}
// Placeholder syntax: "${key}" or "${key:default}".
const (
	placeHolderPrefix           = "${"
	placeHolderSuffix           = "}"
	placeHolderDefaultDelimiter = ":"
)

// bracket records an opening placeholder prefix found while parsing, together with the position where it was found.
type bracket struct {
	value string
	index int
}

// placeholder is one parsed "${key}" / "${key:default}" reference.
type placeholder struct {
	key        string
	defaultVal interface{}
}

// String renders the placeholder back to its textual form: "${key}", or "${key:default}" when a default is set.
func (ph placeholder) String() string {
	if ph.defaultVal != nil {
		return fmt.Sprintf("%s%s%s%v%s", placeHolderPrefix, ph.key, placeHolderDefaultDelimiter, ph.defaultVal, placeHolderSuffix)
	}
	return placeHolderPrefix + ph.key + placeHolderSuffix
}
// parsePlaceHolder extracts all "${key}" / "${key:default}" placeholders from strValue.
//
// isEmbedded means the placeholder is embedded in the value string, either with the format of "somestring${a}"
// or "${a}${b}", i.e. strValue is not exactly one placeholder.
// Nested placeholders (e.g. "${a${b}}") are not supported and produce an error; a "}" without a matching "${"
// is ignored.
//
// Note: when default value is present e.g. "${non-exist.key:default_value}", the type of default value is unknown
// (information about whether the value was quoted is lost during YAML parsing).
// We guess the type based on the default value using strconv package:
// - 100 -> json.Number
// - 100.0 -> float
// - true/false -> bool
// - other values -> string
func parsePlaceHolder(strValue string) (placeholders []placeholder, isEmbedded bool, err error) {
	// stack of currently open "${"; since nesting is rejected it never holds more than one element
	var bracketStack []bracket
	for i := 0; i < len(strValue); i++ {
		rest := strValue[i:]
		if strings.HasPrefix(rest, placeHolderPrefix) {
			bracketStack = append(bracketStack, bracket{value: placeHolderPrefix, index: i})
			if len(bracketStack) > 1 {
				return nil, false, errors.New(strValue + " has nested place holders, which is not supported")
			}
		}
		if !strings.HasPrefix(rest, placeHolderSuffix) || len(bracketStack) == 0 {
			// not a "}", or no matching "${" to close
			continue
		}
		open := bracketStack[len(bracketStack)-1]
		bracketStack = bracketStack[:len(bracketStack)-1]
		body := strValue[open.index+len(placeHolderPrefix) : i]
		key, def, hasDefault := strings.Cut(body, placeHolderDefaultDelimiter)
		ph := placeholder{key: key}
		if hasDefault {
			ph.defaultVal = utils.ParseString(def)
		}
		placeholders = append(placeholders, ph)
		// embedded if anything precedes the "${" or follows the "}"
		if open.index > 0 || i < len(strValue)-1 {
			isEmbedded = true
		}
	}
	return placeholders, isEmbedded, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package appconfig
import "github.com/cisco-open/go-lanai/pkg/bootstrap"
// Well-known property keys used by the configuration framework.
const (
// PropertyKeyActiveProfiles is the key of the active profiles.
PropertyKeyActiveProfiles = "application.profiles.active"
// PropertyKeyAdditionalProfiles is the key of additional profiles; values from all sources are appended
// instead of overridden when merged.
PropertyKeyAdditionalProfiles = "application.profiles.additional"
// PropertyKeyConfigFileSearchPath holds the directory (or list of directories) searched for local config files.
PropertyKeyConfigFileSearchPath = "config.file.search-path"
// PropertyKeyApplicationName aliases bootstrap.PropertyKeyApplicationName.
PropertyKeyApplicationName = bootstrap.PropertyKeyApplicationName
// PropertyKeyBuildInfo is the key of application build information. NOTE(review): consumers not visible here.
PropertyKeyBuildInfo = "application.build"
//PropertyKey = ""
)
// ConfigAccessor exposes loaded configuration: property access via bootstrap.ApplicationConfig,
// iteration over all properties, the effective providers, and the active profiles.
type ConfigAccessor interface {
bootstrap.ApplicationConfig
// Each applies the function to every property and stops at the first error.
Each(apply func(string, interface{}) error) error
// Providers gives effective config providers
Providers() []Provider
// Profiles lists the active profiles.
Profiles() []string
// HasProfile reports whether the given profile is active.
HasProfile(profile string) bool
}
// BootstrapConfig is the configuration assembled from bootstrap provider groups.
type BootstrapConfig struct {
	config
}

// NewBootstrapConfig creates a BootstrapConfig over the given provider groups.
// Its values are unavailable until Load succeeds.
func NewBootstrapConfig(groups ...ProviderGroup) *BootstrapConfig {
	bc := &BootstrapConfig{}
	bc.groups = groups
	return bc
}

// ApplicationConfig is the configuration assembled from application (and bootstrap) provider groups.
type ApplicationConfig struct {
	config
}

// NewApplicationConfig creates an ApplicationConfig over the given provider groups.
// Its values are unavailable until Load succeeds.
func NewApplicationConfig(groups ...ProviderGroup) *ApplicationConfig {
	ac := &ApplicationConfig{}
	ac.groups = groups
	return ac
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package envprovider
import (
"context"
"github.com/cisco-open/go-lanai/pkg/appconfig"
"github.com/cisco-open/go-lanai/pkg/utils"
"os"
"strings"
)
// ConfigProvider is a config provider backed by the OS environment variables.
type ConfigProvider struct {
appconfig.ProviderMeta
}
// dot replaces underscores in environment variable names when building property keys.
const dot = rune('.')
// Name returns the fixed name of this provider.
func (configProvider *ConfigProvider) Name() string {
return "environment-variable"
}
func (configProvider *ConfigProvider) Load(_ context.Context) (loadError error) {
defer func() {
if loadError != nil {
configProvider.Loaded = false
} else {
configProvider.Loaded = true
}
}()
flatSettings := make(map[string]interface{})
for _, e := range os.Environ() {
kv := strings.SplitN(e, "=", 2)
k := kv[0]
v := kv[1]
var runes []rune
for pos, char := range k {
if strings.Compare(string(char), "_") == 0 {
if pos>0 && strings.Compare(string(runes[pos-1]) , "_") != 0 {
runes = append(runes, dot)
} else {
runes = append(runes, char)
}
} else {
runes = append(runes, char)
}
}
flatSettings[string(runes)] = utils.ParseString(v)
}
unFlattenedSettings, loadError := appconfig.UnFlatten(flatSettings)
if loadError != nil {
return loadError
}
configProvider.Settings = unFlattenedSettings
return nil
}
// NewEnvProvider creates an environment-variable provider with the given precedence.
func NewEnvProvider(precedence int) *ConfigProvider {
	meta := appconfig.ProviderMeta{Precedence: precedence}
	return &ConfigProvider{ProviderMeta: meta}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package fileprovider
import (
"context"
"embed"
"fmt"
"github.com/cisco-open/go-lanai/pkg/appconfig"
"github.com/cisco-open/go-lanai/pkg/appconfig/parser"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"io"
"os"
"path"
"strings"
)
var logger = log.New("Config.File")
// ConfigProvider loads properties from the content of a single config file, read from reader on Load.
type ConfigProvider struct {
appconfig.ProviderMeta
reader io.Reader
// filepath is the path the content comes from; used by Name.
filepath string
}
// NewProvider creates a file-based provider that reads its content from reader.
// filePath is used for naming and to select the format from its (case-insensitive) extension.
// Only YAML (.yml/.yaml) is supported; for any other extension a warning is logged and nil is returned.
func NewProvider(precedence int, filePath string, reader io.Reader) *ConfigProvider {
	fileExt := strings.ToLower(path.Ext(filePath))
	if fileExt != ".yml" && fileExt != ".yaml" {
		// TODO: support .ini, .json/.json5, .toml and .properties files
		logger.Warnf("Unknown appconfig file extension: %s", fileExt)
		return nil
	}
	return &ConfigProvider{
		ProviderMeta: appconfig.ProviderMeta{Precedence: precedence},
		reader:       reader,
		filepath:     filePath,
	}
}
// Name returns "file:" followed by the provider's file path.
func (configProvider *ConfigProvider) Name() string {
	return fmt.Sprint("file:", configProvider.filepath)
}
func (configProvider *ConfigProvider) Load(_ context.Context) (loadError error) {
defer func(){
if loadError != nil {
configProvider.Loaded = false
} else {
configProvider.Loaded = true
}
}()
encoded, loadError := io.ReadAll(configProvider.reader)
if loadError != nil {
return loadError
}
settings, loadError := parser.NewYamlPropertyParser()(encoded)
if loadError != nil {
return loadError
}
configProvider.Settings = settings
return nil
}
// NewFileProvidersFromBaseName looks for "<baseName>.<ext>" in each directory listed under
// appconfig.PropertyKeyConfigFileSearchPath (a string or a list of strings). It returns a provider for the
// first regular file found.
//
// The file content is read immediately and the file is closed, so no descriptor is leaked. Stat or read
// failures other than "not exist" are logged and that location is skipped. exists is false when no file
// was found. provider may still be nil when exists is true, if ext is not a supported extension (see
// NewProvider).
func NewFileProvidersFromBaseName(precedence int, baseName string, ext string, conf bootstrap.ApplicationConfig) (provider *ConfigProvider, exists bool) {
	var searchPaths []string
	switch v := conf.Value(appconfig.PropertyKeyConfigFileSearchPath).(type) {
	case string:
		searchPaths = []string{v}
	case []string:
		searchPaths = v
	case []interface{}:
		searchPaths = make([]string, 0, len(v))
		for _, elem := range v {
			// non-string entries are ignored instead of becoming "" (i.e. the working directory)
			if s, ok := elem.(string); ok {
				searchPaths = append(searchPaths, s)
			}
		}
	}

	fileName := baseName + "." + ext
	for _, dir := range searchPaths {
		fullPath := path.Join(dir, fileName)
		info, e := os.Stat(fullPath)
		if e != nil {
			// info is nil on any error; only report unexpected ones
			if !os.IsNotExist(e) {
				logger.Warnf("Unable to access config file %s: %v", fullPath, e)
			}
			continue
		}
		if info.IsDir() {
			continue
		}
		content, e := os.ReadFile(fullPath)
		if e != nil {
			logger.Warnf("Unable to read config file %s: %v", fullPath, e)
			continue
		}
		return NewProvider(precedence, fullPath, strings.NewReader(string(content))), true
	}
	return nil, false
}
// NewEmbeddedFSProvider creates a provider for the file at path (slash-separated) in the given embed.FS.
// exists is false, and provider nil, when the file cannot be opened or its extension is not supported.
func NewEmbeddedFSProvider(precedence int, path string, fs embed.FS) (provider *ConfigProvider, exists bool) {
	file, e := fs.Open(path)
	if e != nil {
		return nil, false
	}
	if provider = NewProvider(precedence, path, file); provider == nil {
		// unsupported extension: nothing will ever read the file, release it
		_ = file.Close()
		return nil, false
	}
	return provider, true
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package appconfig
import (
"embed"
"fmt"
"github.com/cisco-open/go-lanai/pkg/appconfig"
"github.com/cisco-open/go-lanai/pkg/appconfig/fileprovider"
"go.uber.org/fx"
"path/filepath"
"reflect"
)
// fx group names under which bootstrap, application and default providers / provider groups are collected.
const (
FxGroupBootstrap = "bootstrap-config"
FxGroupApplication = "application-config"
FxGroupDefaults = "default-config"
)
// FxEmbeddedDefaults returns a specialized fx.Option that takes a given embed.FS and loads the *.yml files
// found in searchPaths (default ".") as default properties.
func FxEmbeddedDefaults(fs embed.FS, searchPaths...string) fx.Option {
return embeddedFileProviderFxOptions(FxGroupDefaults, fs, searchPaths...)
}
// FxEmbeddedApplicationAdHoc returns a specialized fx.Option that takes a given embed.FS and loads the *.yml files
// found in searchPaths (default ".") as application properties.
func FxEmbeddedApplicationAdHoc(fs embed.FS, searchPaths...string) fx.Option {
return embeddedFileProviderFxOptions(FxGroupApplication, fs, searchPaths...)
}
// FxEmbeddedBootstrapAdHoc returns a specialized fx.Option that takes a given embed.FS and loads the *.yml files
// found in searchPaths (default ".") as bootstrap properties.
func FxEmbeddedBootstrapAdHoc(fs embed.FS, searchPaths...string) fx.Option {
return embeddedFileProviderFxOptions(FxGroupBootstrap, fs, searchPaths...)
}
// FxProvideDefaults wraps the given values as fx.Provide of appconfig.Provider in the default-properties group.
// Supported values are
// - appconfig.Provider
// - a function that returns/creates appconfig.Provider
// Any other value causes a panic.
func FxProvideDefaults(providers ...interface{}) fx.Option {
return providerFxOptions(FxGroupDefaults, providers)
}
// FxProvideApplicationAdHoc wraps the given values as fx.Provide of appconfig.Provider in the group that
// overrides application properties.
// Supported values are
// - appconfig.Provider
// - a function that returns/creates appconfig.Provider
// Any other value causes a panic.
func FxProvideApplicationAdHoc(providers ...interface{}) fx.Option {
return providerFxOptions(FxGroupApplication, providers)
}
// FxProvideBootstrapAdHoc wraps the given values as fx.Provide of appconfig.Provider in the group that
// overrides bootstrap properties.
// Supported values are
// - appconfig.Provider
// - a function that returns/creates appconfig.Provider
// Any other value causes a panic.
func FxProvideBootstrapAdHoc(providers ...interface{}) fx.Option {
return providerFxOptions(FxGroupBootstrap, providers)
}
// providerFxOptions turns each element of providers into an fx.Annotated constructor in fxGroup.
// An appconfig.Provider is wrapped in a constructor returning it, and a function is used as the
// constructor directly. Any other value panics.
func providerFxOptions(fxGroup string, providers []interface{}) fx.Option {
	annotated := make([]interface{}, 0, len(providers))
	for _, candidate := range providers {
		var ctor interface{}
		switch provider, isProvider := candidate.(appconfig.Provider); {
		case isProvider:
			ctor = func() appconfig.Provider {
				return provider
			}
		case reflect.ValueOf(candidate).Kind() == reflect.Func:
			ctor = candidate
		default:
			panic(fmt.Errorf("invalid appconfig.FxProvide...() parameters. Support appconfig.Provider or a provide function, but got %T", candidate))
		}
		annotated = append(annotated, fx.Annotated{Group: fxGroup, Target: ctor})
	}
	return fx.Provide(annotated...)
}
func embeddedFileProviderFxOptions(fxGroup string, fs embed.FS, searchPaths...string) fx.Option {
if len(searchPaths) == 0 {
searchPaths = []string{"."}
}
const ext = "yml"
providers := make([]interface{}, 0)
for _, searchPath := range searchPaths {
files, e := fs.ReadDir(searchPath)
if e != nil {
continue
}
for _, f := range files {
if !f.IsDir() || filepath.Ext(f.Name()) == ext {
providers = append(providers, fxEmbeddedFileProvider(fxGroup, filepath.Join(searchPath, f.Name()), fs))
}
}
}
return fx.Provide(providers...)
}
// embeddedFileCount records the natural order of embedded files across calls; it is incremented on every
// constructor invocation.
var embeddedFileCount int

// fxEmbeddedFileProvider returns an fx constructor, in fxGroup, that creates a provider for the embedded file.
// The constructor returns a nil interface (not a typed nil pointer) when no provider can be created, so that
// consumers' `p == nil` checks work.
func fxEmbeddedFileProvider(fxGroup string, filepath string, fs embed.FS) fx.Annotated {
	fn := func() appconfig.Provider {
		// Note order will be overwritten by corresponding provider group
		// the precedence here is used to record the natural order within the group
		precedence := embeddedFileCount
		embeddedFileCount++
		provider, ok := fileprovider.NewEmbeddedFSProvider(precedence, filepath, fs)
		if !ok || provider == nil {
			return nil
		}
		return provider
	}
	return fx.Annotated{
		Group:  fxGroup,
		Target: fn,
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package appconfig
import (
"embed"
"github.com/cisco-open/go-lanai/pkg/appconfig"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
)
// Precedence of each property source group.
// iota counts every spec in this block, so the blank identifier evaluates to 1*precedenceGap and
// PrecedenceExternalAppContext to 2*precedenceGap. Each following constant adds one precedenceGap.
const (
//preserve gap between different property sources to allow space for profile specific properties.
precedenceGap = 1000
//lower integer means higher precedence, therefore the list here is high to low in terms of precedence
_ = iota * precedenceGap
PrecedenceExternalAppContext
PrecedenceExternalDefaultContext
PrecedenceApplicationAdHoc
PrecedenceBootstrapAdHoc
PrecedenceCommandline
PrecedenceOSEnv
PrecedenceApplicationLocalFile
PrecedenceBootstrapLocalFile
PrecedenceDefault
)
var logger = log.New("Config")
// defaultConfigFS embeds the framework-wide default properties file.
//go:embed defaults-global.yml
var defaultConfigFS embed.FS
// Module registers configuration as priority fx options. It provides:
// - embedded defaults;
// - the bootstrap provider groups and the bootstrap config;
// - the application provider groups and the application config;
// - the global bootstrap.Properties.
// NOTE(review): Name "bootstrap endpoint" looks copy-pasted from another module - confirm before relying on it.
var Module = &bootstrap.Module{
Name: "bootstrap endpoint",
Precedence: bootstrap.AppConfigPrecedence,
PriorityOptions: []fx.Option{
FxEmbeddedDefaults(defaultConfigFS),
fx.Provide(
// Bootstrap groups and config
newCommandProviderGroup,
newOsEnvProviderGroup,
newBootstrapFileProviderGroup,
newDefaultProviderGroup,
newBootstrapAdHocProviderGroup,
newBootstrapConfig,
// Application file & adhoc
newApplicationFileProviderGroup,
newApplicationAdHocProviderGroup,
// App Config
newApplicationConfig,
newGlobalProperties,
),
},
}
// Use is the entrypoint of the appconfig package: it registers Module with bootstrap.
func Use() {
bootstrap.Register(Module)
}
// bootstrapConfigDI collects the dependencies for building the bootstrap config.
type bootstrapConfigDI struct {
	fx.In
	App            *bootstrap.App
	ProviderGroups []appconfig.ProviderGroup `group:"bootstrap-config"`
}

// newBootstrapConfig builds the bootstrap config from all non-nil bootstrap provider groups and loads it eagerly.
// It panics if loading fails.
func newBootstrapConfig(di bootstrapConfigDI) *appconfig.BootstrapConfig {
	var groups []appconfig.ProviderGroup
	for _, group := range di.ProviderGroups {
		if group == nil {
			continue
		}
		groups = append(groups, group)
	}
	cfg := appconfig.NewBootstrapConfig(groups...)
	if err := cfg.Load(di.App.EagerGetApplicationContext(), false); err != nil {
		panic(err)
	}
	return cfg
}
// appConfigDIOut exposes the application config both as its concrete pointer and as bootstrap.ApplicationConfig.
type appConfigDIOut struct {
	fx.Out
	ACPtr *appconfig.ApplicationConfig
	ACI   bootstrap.ApplicationConfig
}

// appConfigDI collects the dependencies for building the application config.
type appConfigDI struct {
	fx.In
	App             *bootstrap.App
	ProviderGroups  []appconfig.ProviderGroup `group:"application-config"`
	BootstrapConfig *appconfig.BootstrapConfig
}

// newApplicationConfig builds the application config and loads it eagerly. Its groups are all non-nil
// application provider groups plus every provider group of the bootstrap config. It panics if loading fails.
func newApplicationConfig(di appConfigDI) appConfigDIOut {
	var groups []appconfig.ProviderGroup
	for _, group := range di.ProviderGroups {
		if group == nil {
			continue
		}
		groups = append(groups, group)
	}
	groups = append(groups, di.BootstrapConfig.ProviderGroups()...)
	cfg := appconfig.NewApplicationConfig(groups...)
	if err := cfg.Load(di.App.EagerGetApplicationContext(), false); err != nil {
		panic(err)
	}
	return appConfigDIOut{ACPtr: cfg, ACI: cfg}
}
// newGlobalProperties binds the whole application config (empty prefix) into bootstrap.Properties.
// It panics if binding fails.
func newGlobalProperties(cfg *appconfig.ApplicationConfig) bootstrap.Properties {
	props := bootstrap.Properties{}
	err := cfg.Bind(&props, "")
	if err != nil {
		panic(err)
	}
	return props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package appconfig
import (
"github.com/cisco-open/go-lanai/pkg/appconfig"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"go.uber.org/fx"
)
// adhocBootstrapDI collects ad-hoc bootstrap providers registered via FxGroupBootstrap.
type adhocBootstrapDI struct {
	fx.In
	Providers []appconfig.Provider `group:"bootstrap-config"`
}

// newBootstrapAdHocProviderGroup groups the ad-hoc bootstrap providers at PrecedenceBootstrapAdHoc.
func newBootstrapAdHocProviderGroup(di adhocBootstrapDI) bootstrapProvidersOut {
	return bootstrapProvidersOut{
		ProviderGroup: newAdHocProviderGroup(PrecedenceBootstrapAdHoc, di.Providers),
	}
}

// adhocApplicationDI collects ad-hoc application providers registered via FxGroupApplication.
type adhocApplicationDI struct {
	fx.In
	Providers []appconfig.Provider `group:"application-config"`
}

// newApplicationAdHocProviderGroup groups the ad-hoc application providers at PrecedenceApplicationAdHoc.
func newApplicationAdHocProviderGroup(di adhocApplicationDI) appConfigProvidersOut {
	return appConfigProvidersOut{
		ProviderGroup: newAdHocProviderGroup(PrecedenceApplicationAdHoc, di.Providers),
	}
}

// newAdHocProviderGroup sorts candidates with order.OrderedFirstCompare (in place) and drops nil entries.
// It re-orders every appconfig.ProviderReorderer to precedence and wraps the result in a static provider
// group of that precedence.
func newAdHocProviderGroup(precedence int, candidates []appconfig.Provider) appconfig.ProviderGroup {
	order.SortStable(candidates, order.OrderedFirstCompare)
	providers := make([]appconfig.Provider, 0, len(candidates))
	for _, p := range candidates {
		if p == nil {
			continue
		}
		if reorder, ok := p.(appconfig.ProviderReorderer); ok {
			reorder.Reorder(precedence)
		}
		providers = append(providers, p)
	}
	return appconfig.NewStaticProviderGroup(precedence, providers...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package appconfig
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/appconfig"
"github.com/cisco-open/go-lanai/pkg/appconfig/cliprovider"
"github.com/cisco-open/go-lanai/pkg/appconfig/envprovider"
"github.com/cisco-open/go-lanai/pkg/appconfig/fileprovider"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"go.uber.org/fx"
)
// bootstrapProvidersOut contributes a provider group to the "bootstrap-config" fx group.
type bootstrapProvidersOut struct {
fx.Out
ProviderGroup appconfig.ProviderGroup `group:"bootstrap-config"`
}
// appConfigProvidersOut contributes a provider group to the "application-config" fx group.
type appConfigProvidersOut struct {
fx.Out
ProviderGroup appconfig.ProviderGroup `group:"application-config"`
}
/*********************
Bootstrap Groups
*********************/
// newCommandProviderGroup creates the bootstrap group holding the Cobra command-line provider
// (created with key prefix "cli.flag.") at PrecedenceCommandline.
func newCommandProviderGroup(execCtx *bootstrap.CliExecContext) bootstrapProvidersOut {
	cli := cliprovider.NewCobraProvider(PrecedenceCommandline, execCtx, "cli.flag.")
	var out bootstrapProvidersOut
	out.ProviderGroup = appconfig.NewStaticProviderGroup(PrecedenceCommandline, cli)
	return out
}

// newOsEnvProviderGroup creates the bootstrap group holding the OS environment provider at PrecedenceOSEnv.
func newOsEnvProviderGroup() bootstrapProvidersOut {
	env := envprovider.NewEnvProvider(PrecedenceOSEnv)
	var out bootstrapProvidersOut
	out.ProviderGroup = appconfig.NewStaticProviderGroup(PrecedenceOSEnv, env)
	return out
}
// newBootstrapFileProviderGroup creates the profile-based bootstrap group backed by local files.
// The base name is "bootstrap" and the profile-specific name is "bootstrap-<profile>"; the extension is
// ".yml". Files are searched in the configured search paths, and the number found is logged.
func newBootstrapFileProviderGroup() bootstrapProvidersOut {
	const name = "bootstrap"
	const ext = "yml"
	group := appconfig.NewProfileBasedProviderGroup(PrecedenceBootstrapLocalFile)
	group.KeyFunc = func(profile string) string {
		if profile != "" {
			return fmt.Sprintf("%s-%s", name, profile)
		}
		return name
	}
	group.CreateFunc = func(key string, precedence int, conf bootstrap.ApplicationConfig) appconfig.Provider {
		if p, exists := fileprovider.NewFileProvidersFromBaseName(precedence, key, ext, conf); exists && p != nil {
			return p
		}
		// a nil interface rather than a typed nil pointer when no file is found
		return nil
	}
	group.ProcessFunc = func(ctx context.Context, providers []appconfig.Provider) []appconfig.Provider {
		if count := len(providers); count > 0 {
			logger.WithContext(ctx).Infof("Found %d bootstrap configuration files", count)
		}
		return providers
	}
	return bootstrapProvidersOut{ProviderGroup: group}
}
type defaultProviderGroupDI struct {
fx.In
ExecCtx *bootstrap.CliExecContext
Providers []appconfig.Provider `group:"default-config"`
}
// newDefaultProviderGroup builds the lowest-level bootstrap group (PrecedenceDefault).
// The CLI static config provider comes first, followed by every non-nil provider contributed to
// "default-config". Contributed providers that support reordering are moved to PrecedenceDefault.
func newDefaultProviderGroup(di defaultProviderGroupDI) bootstrapProvidersOut {
	all := make([]appconfig.Provider, 1, len(di.Providers)+1)
	all[0] = cliprovider.NewStaticConfigProvider(PrecedenceDefault, di.ExecCtx)
	for _, extra := range di.Providers {
		if extra == nil {
			continue
		}
		if reorderer, ok := extra.(appconfig.ProviderReorderer); ok {
			reorderer.Reorder(PrecedenceDefault)
		}
		all = append(all, extra)
	}
	return bootstrapProvidersOut{
		ProviderGroup: appconfig.NewStaticProviderGroup(PrecedenceDefault, all...),
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package appconfig
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/appconfig"
"github.com/cisco-open/go-lanai/pkg/appconfig/fileprovider"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
)
// newApplicationFileProviderGroup creates a profile-based application group backed by local YAML files.
// The default profile maps to base name "application" and every other profile to "application-<profile>".
// Each base name is resolved with fileprovider.NewFileProvidersFromBaseName using extension "yml".
// Keys without a matching file produce no provider.
func newApplicationFileProviderGroup() appConfigProvidersOut {
	const (
		baseName  = "application"
		extension = "yml"
	)
	group := appconfig.NewProfileBasedProviderGroup(PrecedenceApplicationLocalFile)
	group.KeyFunc = func(profile string) string {
		if profile != "" {
			return fmt.Sprintf("%s-%s", baseName, profile)
		}
		return baseName
	}
	group.CreateFunc = func(key string, order int, conf bootstrap.ApplicationConfig) appconfig.Provider {
		fp, found := fileprovider.NewFileProvidersFromBaseName(order, key, extension, conf)
		if found && fp != nil {
			return fp
		}
		// return a literal nil so callers never see a non-nil interface wrapping a nil pointer
		return nil
	}
	group.ProcessFunc = func(ctx context.Context, providers []appconfig.Provider) []appconfig.Provider {
		if n := len(providers); n > 0 {
			logger.WithContext(ctx).Infof("found %d application configuration files", n)
		}
		return providers
	}
	return appConfigProvidersOut{ProviderGroup: group}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package appconfig
import (
"dario.cat/mergo"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"reflect"
"strconv"
"strings"
)
// Options configures how flattened keys are split and joined.
// By default: Delimiter = "."
type Options struct {
// Delimiter separates nested key segments, e.g. "a.b" with ".".
Delimiter string
}
// ProcessKeyFormat traverses the given map depth-first and returns a copy with every map key rewritten
// by processor. This includes keys of maps nested inside []interface{} values.
// processor is called with the key only (no option funcs). Leaf values are copied as-is, and empty
// slices become nil. If two keys are rewritten to the same new key, which value survives is
// unspecified (map iteration order). On error the returned map is nil.
func ProcessKeyFormat(nested map[string]interface{}, processor func(string, ...func(*Options)) string) (map[string]interface{}, error) {
	result, err := processKeyFormat(nested, processor)
	if err != nil {
		return nil, err
	}
	// comma-ok assertion: never panic, even if result is unexpectedly not a map
	m, _ := result.(map[string]interface{})
	return m, nil
}

// processKeyFormat rebuilds value recursively. Maps get their keys rewritten by processor, slices get
// each element processed, and every other value is returned unchanged.
func processKeyFormat(value interface{}, processor func(string, ...func(*Options)) string) (interface{}, error) {
	switch value := value.(type) {
	case map[string]interface{}:
		result := make(map[string]interface{}, len(value))
		// empty map: nothing to process
		if reflect.DeepEqual(value, map[string]interface{}{}) {
			return result, nil
		}
		for k, v := range value {
			newKey := processor(k)
			visitedValue, fe := processKeyFormat(v, processor)
			if fe != nil {
				return nil, fe
			}
			result[newKey] = visitedValue
		}
		return result, nil
	case []interface{}:
		// empty slice: result stays nil
		var result []interface{}
		if reflect.DeepEqual(value, []interface{}{}) {
			return result, nil
		}
		for _, v := range value {
			visitedValue, fe := processKeyFormat(v, processor)
			if fe != nil {
				return nil, fe
			}
			result = append(result, visitedValue)
		}
		return result, nil
	default:
		return value, nil
	}
}
// VisitEach walks nested depth-first and calls apply once per leaf value with its flattened key.
// A leaf is anything that is neither a map[string]interface{} nor a []interface{}.
// Map segments are joined with Options.Delimiter (default "."), and slice elements are addressed as "[i]".
// Empty maps and slices produce no calls. The walk stops at, and returns, the first error from apply.
// Map iteration order is not deterministic.
func VisitEach(nested map[string]interface{}, apply func(string, interface{}) error, configures ...func(*Options)) error {
	opts := &Options{Delimiter: "."}
	for _, fn := range configures {
		fn(opts)
	}
	return recursiveVisit("", nested, apply, opts)
}

// recursiveVisit visits value under the flattened key `key`; it stops at the first error.
func recursiveVisit(key string, value interface{}, apply func(string, interface{}) error, opts *Options) (err error) {
	switch node := value.(type) {
	case map[string]interface{}:
		for k, child := range node {
			childKey := k
			if key != "" {
				childKey = key + opts.Delimiter + k
			}
			if err = recursiveVisit(childKey, child, apply, opts); err != nil {
				return err
			}
		}
		return nil
	case []interface{}:
		for i, child := range node {
			// a root-level slice yields "[i]"; nested ones append "[i]" to the parent key
			childKey := key + "[" + strconv.Itoa(i) + "]"
			if err = recursiveVisit(childKey, child, apply, opts); err != nil {
				return err
			}
		}
		return nil
	default:
		return apply(key, value)
	}
}
// UfOptions configures UnFlatten.
type UfOptions struct {
// Delimiter separates key segments (UnFlatten default: ".").
Delimiter string
// AppendSlice is passed to mergo.Config.AppendSlice when merging un-flattened entries (UnFlatten default: true).
AppendSlice bool
}
// UnFlatten converts a flat map, whose keys are delimiter-separated paths, into a nested map.
// Keys may carry a single index marker, e.g.
//
//	my-example.url[0]=https://example.com
//
// Indexed entries are collected into a list. The index value is only validated, not used for
// positioning, so the resulting list order is not guaranteed to follow the indexes.
// A key with more than one index marker (e.g. a.b[0].c[0]) is rejected with an error.
// Defaults: Delimiter "." and AppendSlice true (passed to mergo.Config.AppendSlice while merging).
// On any error the returned map is nil.
func UnFlatten(flat map[string]interface{}, configures ...func(*UfOptions)) (nested map[string]interface{}, err error) {
	opts := &UfOptions{
		Delimiter:   ".",
		AppendSlice: true,
	}
	for _, configure := range configures {
		configure(opts)
	}
	nested = make(map[string]interface{})
	for k, v := range flat {
		temp, e := uf(k, v, opts)
		if e != nil {
			return nil, errors.Wrap(e, "cannot un-flatten due to error in key: "+k)
		}
		if e := mergo.Merge(&nested, temp, func(c *mergo.Config) {
			c.AppendSlice = opts.AppendSlice
		}); e != nil {
			// do not hand back a partially merged map
			return nil, errors.Wrap(e, "cannot merge un-flattened key: "+k)
		}
	}
	return nested, nil
}
// uf turns a single flat entry k=v into a nested structure, building from the last segment outwards.
// A segment of the form "name[idx]", where idx is a non-negative integer and "name" is non-empty,
// wraps its value in a one-element []interface{}. The idx value is only validated.
// At most one indexed segment per key is allowed. Any other segment is used as a plain map key.
func uf(k string, v interface{}, opts *UfOptions) (n interface{}, err error) {
	indexOccurrence := 0
	n = v
	keys := strings.Split(k, opts.Delimiter)
	for i := len(keys) - 1; i >= 0; i-- {
		currKey := keys[i]
		temp := make(map[string]interface{})
		bracketLeft := strings.Index(currKey, "[")
		bracketRight := strings.Index(currKey, "]")
		if bracketLeft > 0 && bracketRight == len(currKey)-1 {
			index, e := strconv.Atoi(currKey[bracketLeft+1 : bracketRight])
			switch {
			case e != nil:
				return nil, errors.Wrapf(e, "key: %s has index marker [], but the index is not valid integer", k)
			case index < 0:
				// errors.Wrap(nil, ...) returns nil, so negative indexes need an explicit error
				return nil, errors.Errorf("key: %s has index marker [], but the index %d is negative", k, index)
			case indexOccurrence > 0:
				return nil, errors.Errorf("key: %s has multiple index marker []. This is not supported", k)
			}
			currKey = currKey[:bracketLeft]
			temp[currKey] = []interface{}{n}
			indexOccurrence++
		} else {
			temp[currKey] = n
		}
		n = temp
	}
	return n, nil
}
// UnFlattenKey splits a flattened key into its segments using Options.Delimiter (default ".").
func UnFlattenKey(k string, configures ...func(*Options)) []string {
	var opts Options
	opts.Delimiter = "."
	for _, fn := range configures {
		fn(&opts)
	}
	return strings.Split(k, opts.Delimiter)
}
// dash is the '-' rune.
const dash = rune('-')
// NormalizeKey converts every segment of a flattened key with utils.CamelToSnakeCase.
// The key is split and re-joined with the same delimiter (Options.Delimiter, default "."), so a custom
// delimiter is preserved instead of being replaced by ".". With the default options the output is
// unchanged from before.
func NormalizeKey(key string, configures ...func(*Options)) string {
	opts := &Options{Delimiter: "."}
	for _, configure := range configures {
		configure(opts)
	}
	segments := strings.Split(key, opts.Delimiter)
	for i, segment := range segments {
		segments[i] = utils.CamelToSnakeCase(segment)
	}
	return strings.Join(segments, opts.Delimiter)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package parser
import (
"encoding/json"
)
// NewJSONPropertyParser returns a PropertyParser that decodes a JSON object into a map.
// A JSON `null` document yields an empty, non-nil map instead of nil.
// On a decoding error, the (possibly empty) map is returned together with the error.
func NewJSONPropertyParser() PropertyParser {
	return func(encoded []byte) (map[string]interface{}, error) {
		m := make(map[string]interface{})
		if err := json.Unmarshal(encoded, &m); err != nil {
			return m, err
		}
		if m == nil {
			// json.Unmarshal sets a map to nil when the input is the literal `null`
			m = make(map[string]interface{})
		}
		return m, nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package parser
import (
//"gopkg.in/yaml.v2"
"encoding/json"
"github.com/ghodss/yaml"
)
// NewYamlPropertyParser returns a PropertyParser that decodes a YAML document into a map.
// The YAML is converted to JSON first and then decoded with encoding/json, so every nested map is a
// map[string]interface{} (string keys only). A document that converts to JSON `null` (e.g. an empty
// file) yields an empty, non-nil map instead of nil.
func NewYamlPropertyParser() PropertyParser {
	return func(encoded []byte) (map[string]interface{}, error) {
		m := make(map[string]interface{})
		encodedJson, e := yaml.YAMLToJSON(encoded)
		if e != nil {
			return m, e
		}
		if e = json.Unmarshal(encodedJson, &m); e != nil {
			return m, e
		}
		if m == nil {
			// json.Unmarshal sets a map to nil when the input is the literal `null`
			m = make(map[string]interface{})
		}
		return m, nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package appconfig
import (
"context"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils/order"
)
// Provider is a single, ordered source of configuration settings. Settings are populated by Load
// and exposed through GetSettings.
type Provider interface {
order.Ordered
// Name is the unique name of the provider; it is also used as the primary key in any mapping.
Name() string
// Load loads the settings and should be idempotent, i.e. calling it multiple times should not affect the loaded settings.
Load(ctx context.Context) error
// GetSettings returns the loaded settings. The result may be nil unless IsLoaded returns true.
// The returned map should be un-flattened, i.e. flat.key=value should be stored as {"flat":{"key":"value"}}.
GetSettings() map[string]interface{}
// IsLoaded should return true if Load has been invoked at least once since creation or the last Reset.
IsLoaded() bool
// Reset deletes the loaded settings and resets the IsLoaded flag.
Reset()
}
// ProviderReorderer is implemented by providers whose order can be changed after creation.
type ProviderReorderer interface {
// Reorder sets the provider's order to the given value.
Reorder(int)
}
// ProviderGroup determines a set of Providers based on the given bootstrap.ApplicationConfig.
type ProviderGroup interface {
order.Ordered
// Providers returns the providers for the given config.
// This method should be idempotent, i.e. calling it multiple times with the same config always returns an identical slice.
Providers(ctx context.Context, config bootstrap.ApplicationConfig) []Provider
// Reset should mark all providers as unloaded.
Reset()
}
/********************
Common Impl.
********************/
// ProviderMeta is a reusable base for Provider implementations. It implements ProviderReorderer and
// the settings/order/loaded-state part of Provider; Name and Load must come from the concrete provider.
type ProviderMeta struct {
	Loaded     bool                   // settings are invalid if not loaded or while a load is in progress
	Settings   map[string]interface{} // storage for the loaded settings
	Precedence int                    // the precedence (order) with which these settings take effect
}

// GetSettings returns the settings currently held; nil after Reset.
func (meta ProviderMeta) GetSettings() map[string]interface{} { return meta.Settings }

// Order returns the provider's precedence.
func (meta ProviderMeta) Order() int { return meta.Precedence }

// IsLoaded reports the Loaded flag.
func (meta ProviderMeta) IsLoaded() bool { return meta.Loaded }

// Reset clears the settings and the Loaded flag; the precedence is kept.
func (meta *ProviderMeta) Reset() {
	meta.Loaded, meta.Settings = false, nil
}

// Reorder replaces the provider's precedence.
func (meta *ProviderMeta) Reorder(order int) { meta.Precedence = order }
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package appconfig
import (
"context"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"strings"
)
// StaticProviderGroup is a ProviderGroup backed by a fixed list of providers.
// The config passed to Providers is ignored.
type StaticProviderGroup struct {
	Precedence      int
	StaticProviders []Provider
}

// NewStaticProviderGroup creates a group with the given order that holds providers as-is.
func NewStaticProviderGroup(order int, providers ...Provider) *StaticProviderGroup {
	g := new(StaticProviderGroup)
	g.Precedence = order
	g.StaticProviders = providers
	return g
}

// Order returns the group's precedence.
func (g StaticProviderGroup) Order() int { return g.Precedence }

// Providers returns the fixed provider list regardless of the config.
func (g StaticProviderGroup) Providers(_ context.Context, _ bootstrap.ApplicationConfig) []Provider {
	return g.StaticProviders
}

// Reset resets every provider in the group.
func (g *StaticProviderGroup) Reset() {
	for _, provider := range g.StaticProviders {
		provider.Reset()
	}
}
// DynamicProviderGroup implements ProviderGroup on top of a keyed lookup of providers.
// ProviderKeys lists the keys in order. Providers resolves them and caches a non-nil result in
// ResolvedProviders until Reset is called or the field is cleared.
// This type is typically used as an embedded struct.
type DynamicProviderGroup struct {
	Precedence        int
	ProviderKeys      []string // expected to be kept sorted according to the providers' ordering
	ProviderLookup    map[string]Provider
	ResolvedProviders []Provider
	ProcessFunc       func(context.Context, []Provider) []Provider // invoked before ResolvedProviders is set; last chance to change the result
}

// NewDynamicProviderGroup creates an empty group with the given precedence.
func NewDynamicProviderGroup(order int) *DynamicProviderGroup {
	g := &DynamicProviderGroup{Precedence: order}
	g.ProviderKeys = []string{}
	g.ProviderLookup = map[string]Provider{}
	return g
}

// Order returns the group's precedence.
func (g *DynamicProviderGroup) Order() int { return g.Precedence }

// Providers returns the cached providers if present; otherwise it resolves them from ProviderKeys.
// Keys are walked from last to first and keys without a lookup entry are skipped, so the result is in
// reverse key order. Each reorderable provider gets a fresh order: g.Precedence for the first one
// found, then one less for each subsequent one. ProcessFunc (if set) may transform the result
// before it is cached.
func (g *DynamicProviderGroup) Providers(ctx context.Context, _ bootstrap.ApplicationConfig) (providers []Provider) {
	if cached := g.ResolvedProviders; cached != nil {
		return cached
	}
	nextOrder := g.Precedence
	for idx := len(g.ProviderKeys) - 1; idx >= 0; idx-- {
		p, found := g.ProviderLookup[g.ProviderKeys[idx]]
		if !found {
			continue
		}
		providers = append(providers, p)
		if reorderer, ok := p.(ProviderReorderer); ok {
			reorderer.Reorder(nextOrder)
		}
		nextOrder--
	}
	if g.ProcessFunc != nil {
		providers = g.ProcessFunc(ctx, providers)
	}
	g.ResolvedProviders = providers
	return providers
}

// Reset resets every known provider and drops the cached resolution.
func (g *DynamicProviderGroup) Reset() {
	for _, provider := range g.ProviderLookup {
		provider.Reset()
	}
	g.ResolvedProviders = nil
}
// ProfileBasedProviderGroup extends DynamicProviderGroup and implements ProviderGroup.
// Its provider keys are derived from the profiles in the application config (see resolveProfiles).
// KeyFunc maps a profile name ("" for the default profile) to a provider key. CreateFunc builds the
// provider for a key and may return nil when there is nothing to load for it.
type ProfileBasedProviderGroup struct {
	DynamicProviderGroup
	KeyFunc    func(profile string) (key string)
	CreateFunc func(name string, order int, conf bootstrap.ApplicationConfig) Provider
}

// NewProfileBasedProviderGroup creates an empty group with the given precedence.
// KeyFunc and CreateFunc must be set before Providers is called.
func NewProfileBasedProviderGroup(order int) *ProfileBasedProviderGroup {
	dynamic := NewDynamicProviderGroup(order)
	return &ProfileBasedProviderGroup{DynamicProviderGroup: *dynamic}
}
// Providers resolves the providers for the profiles found in conf (see resolveProfiles).
// Steps:
//   - compute one key per profile with KeyFunc; duplicate keys are ignored and the first occurrence is kept
//   - create missing providers with CreateFunc
//   - drop providers whose keys are no longer in use
//
// The cached ResolvedProviders is invalidated whenever the key list or the set of providers changes, so
// removed or reordered profiles are always reflected. Otherwise the cached result is returned.
func (g *ProfileBasedProviderGroup) Providers(ctx context.Context, conf bootstrap.ApplicationConfig) (providers []Provider) {
	profiles := resolveProfiles(conf)
	prevKeys := g.ProviderKeys
	keys := make([]string, 0, len(profiles))
	names := make(map[string]struct{}, len(profiles))
	changed := false
	for _, pf := range profiles {
		name := g.KeyFunc(pf)
		if _, dup := names[name]; dup {
			continue
		}
		names[name] = struct{}{}
		keys = append(keys, name)
		if p, ok := g.ProviderLookup[name]; !ok || p == nil {
			if p = g.CreateFunc(name, g.Precedence, conf); p != nil {
				g.ProviderLookup[name] = p
				changed = true
			}
		}
	}
	// drop providers of keys no longer in use: prevents memory leaks and stale settings
	for k := range g.ProviderLookup {
		if _, ok := names[k]; !ok {
			delete(g.ProviderLookup, k)
			changed = true
		}
	}
	// keys may also change (e.g. profile order) without any provider being added or removed
	changed = changed || len(prevKeys) != len(keys)
	for i := 0; !changed && i < len(keys); i++ {
		changed = prevKeys[i] != keys[i]
	}
	g.ProviderKeys = keys
	if changed {
		g.ResolvedProviders = nil
	}
	return g.DynamicProviderGroup.Providers(ctx, conf)
}
// resolveProfiles lists the profiles to load, in this order: active profiles, then additional
// profiles, then "" (the default profile). Surrounding whitespace is trimmed and blank entries are
// skipped. Each property is accepted either as []interface{} (holding strings) or as []string;
// previously the "wrong" shape for either property was silently ignored.
func resolveProfiles(conf bootstrap.ApplicationConfig) (profiles []string) {
	profiles = appendProfileNames(profiles, conf.Value(PropertyKeyActiveProfiles))
	profiles = appendProfileNames(profiles, conf.Value(PropertyKeyAdditionalProfiles))
	// default profile
	return append(profiles, "")
}

// appendProfileNames appends the trimmed, non-blank profile names held by v to profiles.
func appendProfileNames(profiles []string, v interface{}) []string {
	switch list := v.(type) {
	case []interface{}:
		for _, item := range list {
			s, _ := item.(string)
			if s = strings.TrimSpace(s); s != "" {
				profiles = append(profiles, s)
			}
		}
	case []string:
		for _, s := range list {
			if s = strings.TrimSpace(s); s != "" {
				profiles = append(profiles, s)
			}
		}
	}
	return profiles
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package acm
import (
"context"
"github.com/aws/aws-sdk-go-v2/service/acm"
awsclient "github.com/cisco-open/go-lanai/pkg/aws"
)
// ClientFactory creates ACM clients; opts customize the acm.Options of each new client.
type ClientFactory interface {
New(ctx context.Context, opts...func(opt *acm.Options)) (*acm.Client, error)
}
// NewClientFactory returns a ClientFactory that builds ACM clients from the aws.Config produced by loader.
func NewClientFactory(loader awsclient.ConfigLoader) ClientFactory {
	f := acmFactory{configLoader: loader}
	return &f
}

// acmFactory is the default ClientFactory implementation.
type acmFactory struct {
	configLoader awsclient.ConfigLoader
}

// New loads the AWS configuration through the factory's loader and creates an ACM client from it,
// applying opts. Configuration loading errors are returned unchanged.
func (f *acmFactory) New(ctx context.Context, opts ...func(opt *acm.Options)) (*acm.Client, error) {
	cfg, err := f.configLoader.Load(ctx)
	if err != nil {
		return nil, err
	}
	client := acm.NewFromConfig(cfg, opts...)
	return client, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package acm
import (
"context"
"github.com/aws/aws-sdk-go-v2/service/acm"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
"go.uber.org/fx"
)
// regDI collects RegisterHealth's dependencies; the health registrar is optional.
type regDI struct {
	fx.In
	HealthRegistrar health.Registrar `optional:"true"`
	AcmClient       *acm.Client
}

// RegisterHealth registers a HealthIndicator for the ACM client when a health.Registrar is available;
// otherwise it does nothing.
func RegisterHealth(di regDI) {
	if registrar := di.HealthRegistrar; registrar != nil {
		registrar.MustRegister(&HealthIndicator{AcmClient: di.AcmClient})
	}
}
// HealthIndicator reports ACM client health by calling GetAccountConfiguration.
type HealthIndicator struct {
	AcmClient *acm.Client
}

// Name returns the indicator name, "aws.acm".
func (i *HealthIndicator) Name() string {
	return "aws.acm"
}

// Health reports StatusUp when GetAccountConfiguration succeeds. Otherwise it logs a warning and
// reports StatusUnknown. The health options are ignored.
func (i *HealthIndicator) Health(ctx context.Context, _ health.Options) health.Health {
	_, err := i.AcmClient.GetAccountConfiguration(ctx, &acm.GetAccountConfigurationInput{})
	if err == nil {
		return health.NewDetailedHealth(health.StatusUp, "aws connect succeeded", nil)
	}
	logger.WithContext(ctx).Warnf("AWS ACM connection not available or identity invalid: %v", err)
	return health.NewDetailedHealth(health.StatusUnknown, "AWS ACM connection not available or identity invalid", nil)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package acm
import (
"github.com/aws/aws-sdk-go-v2/service/acm"
awsclient "github.com/cisco-open/go-lanai/pkg/aws"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
)
// logger is this package's logger (named "Aws").
var logger = log.New("Aws")
// Module is the "ACM" bootstrap module. It provides a ClientFactory and a default *acm.Client, and
// registers the ACM health indicator when a health registrar is available.
var Module = &bootstrap.Module{
Name: "ACM",
Precedence: bootstrap.AwsPrecedence,
Options: []fx.Option{
fx.Provide(NewClientFactory),
fx.Provide(NewDefaultClient),
fx.Invoke(RegisterHealth),
},
}
// Use registers this ACM module together with the shared AWS module (awsclient.Module) with the
// global bootstrapper. Call it from main() to enable ACM support.
func Use() {
	bootstrap.Register(Module)
	bootstrap.Register(awsclient.Module)
}
// NewDefaultClient builds the application's default ACM client via factory.
// The application context is used as the context.Context for loading the AWS configuration.
func NewDefaultClient(ctx *bootstrap.ApplicationContext, factory ClientFactory) (*acm.Client, error) {
	client, err := factory.New(ctx)
	return client, err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package aws
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"os"
)
// errTmpl formats AWS configuration validation errors.
const errTmpl = `invalid AWS configuration: %v`
// ConfigLoader loads an aws.Config; opts are extra load options applied by the implementation.
type ConfigLoader interface {
Load(ctx context.Context, opts ...config.LoadOptionsFunc) (aws.Config, error)
}
// ConfigOverrideFunc mutates an already-loaded aws.Config in place.
type ConfigOverrideFunc func(cfg *aws.Config)
// NewConfigLoader creates a ConfigLoader that builds aws.Config from a copy of p.
// customizers are extra load options and overrides are applied after loading
// (see PropertiesBasedConfigLoader.Load).
func NewConfigLoader(p Properties, customizers []config.LoadOptionsFunc, overrides []ConfigOverrideFunc) ConfigLoader {
	loader := &PropertiesBasedConfigLoader{Customizers: customizers, Overrides: overrides}
	loader.Properties = &p
	return loader
}

// PropertiesBasedConfigLoader implements ConfigLoader using Properties plus optional customizations.
type PropertiesBasedConfigLoader struct {
	Properties  *Properties
	Customizers []config.LoadOptionsFunc
	Overrides   []ConfigOverrideFunc
}
// Load builds an aws.Config from the loader's Properties.
// Load options are applied in this order: WithBasicProperties, WithCredentialsProperties, the loader's
// Customizers, then opts. Customizers and opts are also used to configure the STS client for
// CredentialsTypeSTS. After loading, OverrideConfigWithProperties and then the loader's Overrides
// are applied to the result.
func (l *PropertiesBasedConfigLoader) Load(ctx context.Context, opts ...config.LoadOptionsFunc) (cfg aws.Config, err error) {
	// Build a fresh slice: append(l.Customizers, opts...) could write into l.Customizers' spare capacity
	// and clobber the options of other (possibly concurrent) Load calls.
	extraOpts := make([]config.LoadOptionsFunc, 0, len(l.Customizers)+len(opts))
	extraOpts = append(extraOpts, l.Customizers...)
	extraOpts = append(extraOpts, opts...)

	allOpts := make([]config.LoadOptionsFunc, 0, len(extraOpts)+2)
	allOpts = append(allOpts,
		WithBasicProperties(l.Properties),
		WithCredentialsProperties(ctx, l.Properties, extraOpts...),
	)
	allOpts = append(allOpts, extraOpts...)
	if cfg, err = LoadConfig(ctx, allOpts...); err != nil {
		return aws.Config{}, err
	}
	OverrideConfigWithProperties(l.Properties)(&cfg)
	for _, override := range l.Overrides {
		override(&cfg)
	}
	return cfg, nil
}
// LoadConfig calls config.LoadDefaultConfig with opts, converting each config.LoadOptionsFunc to the
// plain func type LoadDefaultConfig accepts. Option order is preserved.
func LoadConfig(ctx context.Context, opts ...config.LoadOptionsFunc) (aws.Config, error) {
	converted := make([]func(*config.LoadOptions) error, 0, len(opts))
	for _, opt := range opts {
		converted = append(converted, opt)
	}
	return config.LoadDefaultConfig(ctx, converted...)
}
// WithBasicProperties sets the region from p; it fails when no region is configured.
func WithBasicProperties(p *Properties) config.LoadOptionsFunc {
	return func(opt *config.LoadOptions) error {
		if p.Region == "" {
			return fmt.Errorf(errTmpl, "Region is not set")
		}
		opt.Region = p.Region
		// The endpoint cannot be set here; it is applied by OverrideConfigWithProperties.
		return nil
	}
}
// WithCredentialsProperties selects the credentials provider according to p.Credentials.Type:
//   - static: fixed id/secret
//   - sts: web identity via NewStsCredentialsProvider, using globalOpts for the STS client
//   - anything else: environment variables
func WithCredentialsProperties(ctx context.Context, p *Properties, globalOpts ...config.LoadOptionsFunc) config.LoadOptionsFunc {
	return func(opt *config.LoadOptions) error {
		credType := p.Credentials.Type
		if credType == CredentialsTypeStatic {
			opt.Credentials = credentials.NewStaticCredentialsProvider(p.Credentials.Id, p.Credentials.Secret, "static_auth")
			return nil
		}
		if credType == CredentialsTypeSTS {
			provider, err := NewStsCredentialsProvider(ctx, p, globalOpts...)
			opt.Credentials = provider
			if err != nil {
				return fmt.Errorf(errTmpl, err)
			}
			return nil
		}
		opt.Credentials = NewEnvCredentialsProvider()
		return nil
	}
}
// OverrideConfigWithProperties returns a ConfigOverrideFunc that sets aws.Config.BaseEndpoint from
// props.Endpoint when it is non-empty.
func OverrideConfigWithProperties(props *Properties) ConfigOverrideFunc {
	return func(cfg *aws.Config) {
		// Since aws-sdk-go-v2 v1.27.27, aws.EndpointResolverWithOptionsFunc is deprecated and each client
		// has its own EndpointResolver; only BaseEndpoint needs setting. It cannot be set through
		// config.LoadOptionsFunc, hence this post-load override.
		if props.Endpoint == "" {
			return
		}
		cfg.BaseEndpoint = aws.String(props.Endpoint)
	}
}
// NewStsCredentialsProvider creates a web-identity role credentials provider (stscreds).
// The token file and role ARN come from p.Credentials and fall back to the AWS_WEB_IDENTITY_TOKEN_FILE
// and AWS_ROLE_ARN environment variables when empty. The STS client is configured from
// WithBasicProperties plus opts, with the endpoint override from p applied.
// A configuration loading error is returned wrapped.
func NewStsCredentialsProvider(ctx context.Context, p *Properties, opts ...config.LoadOptionsFunc) (aws.CredentialsProvider, error) {
	tokenPath := p.Credentials.TokenFile
	if tokenPath == "" {
		tokenPath = os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
	}
	roleArn := p.Credentials.RoleARN
	if roleArn == "" {
		roleArn = os.Getenv("AWS_ROLE_ARN")
	}
	// prepare config for STS in a fresh slice so the caller's opts are never modified
	stsOpts := make([]config.LoadOptionsFunc, 0, len(opts)+1)
	stsOpts = append(stsOpts, WithBasicProperties(p))
	stsOpts = append(stsOpts, opts...)
	cfg, e := LoadConfig(ctx, stsOpts...)
	if e != nil {
		return nil, fmt.Errorf(`unable to prepare for STS credentials: %w`, e)
	}
	OverrideConfigWithProperties(p)(&cfg)
	// create provider
	client := sts.NewFromConfig(cfg)
	provider := stscreds.NewWebIdentityRoleProvider(client, roleArn, stscreds.IdentityTokenFile(tokenPath),
		func(o *stscreds.WebIdentityRoleOptions) {
			o.RoleSessionName = p.Credentials.RoleSessionName
		})
	return provider, nil
}
// NewEnvCredentialsProvider returns a provider that reads credentials from the environment on every
// retrieval:
//   - access key: AWS_ACCESS_KEY_ID, or AWS_ACCESS_KEY if that is empty
//   - secret: AWS_SECRET_ACCESS_KEY, or AWS_SECRET_KEY if that is empty
//   - session token: AWS_SESSION_TOKEN
//
// It never returns an error, even if the variables are unset.
func NewEnvCredentialsProvider() aws.CredentialsProvider {
	// firstEnv returns the first non-empty value among the named variables, or "".
	firstEnv := func(names ...string) string {
		for _, name := range names {
			if v := os.Getenv(name); v != "" {
				return v
			}
		}
		return ""
	}
	return aws.CredentialsProviderFunc(func(_ context.Context) (aws.Credentials, error) {
		creds := aws.Credentials{
			AccessKeyID:     firstEnv("AWS_ACCESS_KEY_ID", "AWS_ACCESS_KEY"),
			SecretAccessKey: firstEnv("AWS_SECRET_ACCESS_KEY", "AWS_SECRET_KEY"),
			SessionToken:    os.Getenv("AWS_SESSION_TOKEN"),
		}
		return creds, nil
	})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package aws
import (
"github.com/aws/aws-sdk-go-v2/config"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"go.uber.org/fx"
)
// FxGroup is the fx value-group name for AWS customizations; CfgLoaderDI consumes both
// config.LoadOptionsFunc and ConfigOverrideFunc values from this group.
const FxGroup = `aws`
// Module is the "AWS" bootstrap module; it provides the bound Properties and the ConfigLoader.
var Module = &bootstrap.Module{
Name: "AWS",
Precedence: bootstrap.AwsPrecedence,
Options: []fx.Option{
fx.Provide(BindAwsProperties, ProvideConfigLoader),
},
}
// FxCustomizerProvider annotates constructor so that its result joins the FxGroup value group.
// Use it to contribute load-option customizers or config overrides.
func FxCustomizerProvider(constructor interface{}) fx.Annotated {
	var annotated fx.Annotated
	annotated.Group = FxGroup
	annotated.Target = constructor
	return annotated
}
// CfgLoaderDI collects ProvideConfigLoader's dependencies: the bound Properties plus all load-option
// customizers and config overrides contributed to the "aws" fx group (same name as FxGroup).
type CfgLoaderDI struct {
fx.In
Properties Properties
Customizers []config.LoadOptionsFunc `group:"aws"`
Overrides []ConfigOverrideFunc `group:"aws"`
}
// ProvideConfigLoader is the fx constructor for ConfigLoader.
func ProvideConfigLoader(di CfgLoaderDI) ConfigLoader {
	loader := NewConfigLoader(di.Properties, di.Customizers, di.Overrides)
	return loader
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package aws
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/pkg/errors"
)
const (
// ConfigRootACM is the configuration prefix AWS Properties are bound from (despite the name, it is "aws").
ConfigRootACM = "aws"
)
const (
// CredentialsTypeStatic uses the fixed id/secret from Credentials.
CredentialsTypeStatic CredentialsType = `static`
// CredentialsTypeSTS uses web-identity role credentials from STS.
CredentialsTypeSTS CredentialsType = `sts`
)
// CredentialsType selects how AWS credentials are obtained; any other value means environment credentials.
type CredentialsType string
// Properties describes common config used to consume AWS services
type Properties struct {
// Region for AWS client. Default: us-east-1
Region string `json:"region"`
// Endpoint for AWS client, used as "BaseEndpoint". Default: empty. Useful for testing against localstack.
Endpoint string `json:"endpoint"`
// Credentials to be used to authenticate the AWS client
Credentials Credentials `json:"credentials"`
}
// Credentials defines the type of credentials to use for AWS
type Credentials struct {
//Type is one of static, env or sts. Defaults to env; any value other than static or sts uses environment credentials.
Type CredentialsType `json:"type"`
//The following are only relevant to static credentials.
//Id is the AWS_ACCESS_KEY_ID for the account
Id string `json:"id"`
//Secret is the AWS_SECRET_ACCESS_KEY
Secret string `json:"secret"`
//The following are relevant to sts credentials (used in EKS).
//RoleARN is the role to be assumed by the application; if omitted, the AWS_ROLE_ARN environment variable is used.
RoleARN string `json:"role-arn"`
//TokenFile is the path to the STS OIDC token file; if omitted, the AWS_WEB_IDENTITY_TOKEN_FILE environment variable is used.
TokenFile string `json:"token-file"`
//RoleSessionName is the name to associate with the session, e.g. the service account.
RoleSessionName string `json:"role-session-name"`
}
// NewProperties returns Properties with defaults: region "us-east-1" and "env" credentials.
func NewProperties() Properties {
	var props Properties
	props.Region = "us-east-1"
	props.Credentials.Type = "env"
	return props
}
// BindAwsProperties binds the configuration under ConfigRootACM ("aws") onto the defaults from
// NewProperties. It panics if binding fails.
func BindAwsProperties(ctx *bootstrap.ApplicationContext) Properties {
	props := NewProperties()
	if err := ctx.Config().Bind(&props, ConfigRootACM); err != nil {
		// the bound type lives in package aws, not acm
		panic(errors.Wrap(err, "failed to bind aws.Properties"))
	}
	return props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package bootstrap
import (
"context"
"github.com/cisco-open/go-lanai/pkg/utils"
"go.uber.org/fx"
"os"
"sort"
"strings"
"sync"
"time"
)
/**************************
Bootstrapper
**************************/
var (
// once guards the lazy creation of bootstrapperInstance.
once sync.Once
// bootstrapperInstance is the global Bootstrapper; access it via bootstrapper().
bootstrapperInstance *Bootstrapper
)
// ContextOption derives a new context.Context from the given one.
type ContextOption func(ctx context.Context) context.Context
/**************************
Singleton Pattern
**************************/
// GlobalBootstrapper returns the globally configured Bootstrapper, creating it on first use.
// This is the bootstrapper used by Execute; every package-level function works against this instance.
func GlobalBootstrapper() *Bootstrapper {
	return bootstrapper()
}

// bootstrapper returns the global Bootstrapper, creating it exactly once.
func bootstrapper() *Bootstrapper {
	once.Do(func() { bootstrapperInstance = NewBootstrapper() })
	return bootstrapperInstance
}

// Register adds m to the global Bootstrapper.
func Register(m *Module) {
	b := bootstrapper()
	b.Register(m)
}

// AddOptions appends fx options to the global Bootstrapper's ad-hoc module.
func AddOptions(options ...fx.Option) {
	b := bootstrapper()
	b.AddOptions(options...)
}

// AddInitialAppContextOptions appends options for the initial application context.
func AddInitialAppContextOptions(options ...ContextOption) {
	b := bootstrapper()
	b.AddInitialAppContextOptions(options...)
}

// AddStartContextOptions appends options for the start context.
func AddStartContextOptions(options ...ContextOption) {
	b := bootstrapper()
	b.AddStartContextOptions(options...)
}

// AddStopContextOptions appends options for the stop context.
func AddStopContextOptions(options ...ContextOption) {
	b := bootstrapper()
	b.AddStopContextOptions(options...)
}
/**************************
Bootstrapper
**************************/
// Bootstrapper stores application configurations for bootstrapping:
// registered modules, ad-hoc fx options and the context options used by the resulting App.
type Bootstrapper struct {
	// modules registered via Register; NewApp expands them with their sub-modules.
	modules utils.GenericSet[*Module]
	// adhocModule accumulates fx.Options added via AddOptions.
	adhocModule *Module
	// initCtxOpts are applied when NewApp creates the ApplicationContext.
	initCtxOpts []ContextOption
	// startCtxOpts are handed to the App and applied to the start context in App.Run.
	startCtxOpts []ContextOption
	// stopCtxOpts are handed to the App and applied to the stop context in App.Run.
	stopCtxOpts []ContextOption
}
// NewBootstrapper creates a Bootstrapper with no registered modules and an empty ad-hoc module.
// Note: the "bootstrap" package drives application bootstrap through a package-level singleton
// (see GlobalBootstrapper), so calling this function directly is not recommended.
//
// It is exported so that test packages can build isolated Bootstrappers.
func NewBootstrapper() *Bootstrapper {
	b := &Bootstrapper{}
	b.modules = utils.NewGenericSet[*Module]()
	b.adhocModule = newAnonymousModule()
	return b
}

// Register adds m to the set of modules used by NewApp.
func (b *Bootstrapper) Register(m *Module) {
	b.modules.Add(m)
}

// AddOptions appends fx.Options to the ad-hoc module, which NewApp always includes.
func (b *Bootstrapper) AddOptions(options ...fx.Option) {
	adhoc := b.adhocModule
	adhoc.Options = append(adhoc.Options, options...)
}

// AddInitialAppContextOptions registers options applied when the ApplicationContext is created.
func (b *Bootstrapper) AddInitialAppContextOptions(options ...ContextOption) {
	b.initCtxOpts = append(b.initCtxOpts, options...)
}

// AddStartContextOptions registers options applied to the start context in App.Run.
func (b *Bootstrapper) AddStartContextOptions(options ...ContextOption) {
	b.startCtxOpts = append(b.startCtxOpts, options...)
}

// AddStopContextOptions registers options applied to the stop context in App.Run.
func (b *Bootstrapper) AddStopContextOptions(options ...ContextOption) {
	b.stopCtxOpts = append(b.stopCtxOpts, options...)
}

// EnableCliRunnerMode implements CliRunnerEnabler for this Bootstrapper.
func (b *Bootstrapper) EnableCliRunnerMode(runnerProviders ...interface{}) {
	enableCliRunnerMode(b, runnerProviders)
}
// NewApp assembles an App from this Bootstrapper's modules plus the given options, and constructs
// the underlying fx.App (which executes all fx.Provide/fx.Invoke options).
//
// priorityOptions and regularOptions are placed in the application's "main" module.
// Modules are expanded transitively through Module.Modules. Nil modules, whether registered
// directly or listed as sub-modules, are ignored. The flattened list is stable-sorted by
// ascending Precedence, and every module's PriorityOptions are emitted before any module's
// regular Options.
func (b *Bootstrapper) NewApp(cliCtx *CliExecContext, priorityOptions []fx.Option, regularOptions []fx.Option) *App {
	// create App
	app := &App{
		ctx:          NewApplicationContext(b.initCtxOpts...),
		startCtxOpts: b.startCtxOpts,
		stopCtxOpts:  b.stopCtxOpts,
	}

	// built-in modules
	initModule := InitModule(cliCtx, app)
	miscModules := MiscModules()

	// ad-hoc fx options passed by the caller
	mainModule := newApplicationMainModule()
	mainModule.PriorityOptions = append(mainModule.PriorityOptions, priorityOptions...)
	mainModule.Options = append(mainModule.Options, regularOptions...)

	// Expand sub-modules until no new module is discovered. Go allows adding to a map while ranging
	// over it, but newly added entries may or may not be visited, hence the outer fixed-point loop.
	resolvedModules := b.modules.Copy()
	resolvedModules.Add(initModule, mainModule, b.adhocModule)
	resolvedModules.Add(miscModules...)
	for changed := true; changed; {
		before := len(resolvedModules)
		for module := range resolvedModules {
			if module == nil {
				// tolerate nil registrations / nil sub-module entries instead of panicking
				continue
			}
			resolvedModules.Add(module.Modules...)
		}
		changed = before != len(resolvedModules)
	}

	modules := make([]*Module, 0, len(resolvedModules))
	for _, m := range resolvedModules.Values() {
		if m != nil {
			modules = append(modules, m)
		}
	}
	sort.SliceStable(modules, func(i, j int) bool { return modules[i].Precedence < modules[j].Precedence })

	// priority options of every module first, then regular options
	var options []fx.Option
	for _, m := range modules {
		options = append(options, m.PriorityOptions...)
	}
	for _, m := range modules {
		options = append(options, m.Options...)
	}

	// create fx.App, which will kick off all fx options
	app.App = fx.New(options...)
	return app
}
/**************************
Application
**************************/
// App wraps *fx.App together with the ApplicationContext and the context options used by Run.
type App struct {
	*fx.App
	// ctx is the ApplicationContext created by Bootstrapper.NewApp.
	ctx *ApplicationContext
	// startCtxOpts are applied to the start context in Run.
	startCtxOpts []ContextOption
	// stopCtxOpts are applied to the stop context in Run.
	stopCtxOpts []ContextOption
}
// EagerGetApplicationContext returns the global ApplicationContext before it becomes available for dependency injection
// Important: packages should typically get ApplicationContext via fx's dependency injection,
//
// which internal application config are guaranteed.
// Only packages involved in priority bootstrap (appconfig, consul, vault, etc)
// should use this function for logging purpose
func (app *App) EagerGetApplicationContext() *ApplicationContext {
return app.ctx
}
func (app *App) Run() {
// to be revised:
// 1. (Solved) Support Timeout in bootstrap.Context
// 2. (Solved) Restore logging
var cancel context.CancelFunc
done := app.Done()
startCtx := app.ctx.Context
for _, opt := range app.startCtxOpts {
startCtx = opt(startCtx)
}
// This is so that we know that the context in the life cycle hook is the child of bootstrap context
startCtx, cancel = context.WithTimeout(startCtx, app.StartTimeout())
defer cancel()
// log error and exit
if err := app.Start(startCtx); err != nil {
logger.WithContext(startCtx).Errorf("Failed to start up: %v", err)
exit(1)
}
// this line blocks until application shutting down
printSignal(<-done)
// shutdown sequence
stopCtx := context.WithValue(app.ctx.Context, ctxKeyStopTime, time.Now().UTC())
for _, opt := range app.stopCtxOpts {
stopCtx = opt(stopCtx)
}
stopCtx, cancel = context.WithTimeout(stopCtx, app.StopTimeout())
defer cancel()
if err := app.Stop(stopCtx); err != nil {
logger.WithContext(stopCtx).Errorf("Shutdown with Error: %v", err)
exit(1)
}
}
// printSignal logs the name of the received shutdown signal in upper case.
// The signal text is passed as an argument rather than used as the format string, so a '%'
// in it cannot be misinterpreted (and `go vet` no longer reports a non-constant format string).
func printSignal(signal os.Signal) {
	if signal == nil {
		// Calling String() on a nil signal would panic and skip the graceful stop sequence in Run.
		return
	}
	logger.Infof("%s", strings.ToUpper(signal.String()))
}

// exit terminates the process with the given status code.
func exit(code int) {
	os.Exit(code)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package bootstrap
import (
"encoding/json"
"github.com/cisco-open/go-lanai/pkg/utils"
"path"
"strings"
"time"
)
var (
	// Build metadata, to be overridden by -ldflags, e.g.
	//   go build -ldflags "-X github.com/cisco-open/go-lanai/pkg/bootstrap.BuildVersion=1.2.3"
	// The Go linker's -X flag only takes effect on string variables that are uninitialized or
	// initialized to a constant string, so none of these may use a function call in its initializer.
	BuildVersion = "Unknown"
	// BuildTime is expected in utils.ISO8601Seconds layout. When it is not set at link time,
	// init() defaults it to the package initialization time.
	BuildTime string
	BuildHash = "Unknown"
	// BuildDeps is a comma-separated list of "module/path@version" entries, parsed into BuildInfo.Modules.
	BuildDeps = "github.com/cisco-open/go-lanai@main"
)

var (
	// BuildInfo is the build metadata of the running binary. Its BuildTime and Modules are filled in by init().
	BuildInfo = BuildInfoMetadata{
		Version: BuildVersion,
		Hash:    BuildHash,
		Modules: ModuleBuildInfoMap{},
	}
	// BuildInfoMap is BuildInfo converted to a generic map, computed by init().
	BuildInfoMap map[string]interface{}
)

// BuildInfoResolver resolves the build metadata reported for the application.
type BuildInfoResolver interface {
	Resolve() BuildInfoMetadata
}

func init() {
	if BuildTime == "" {
		BuildTime = time.Now().Format(utils.ISO8601Seconds)
	}
	BuildInfo.BuildTime = utils.ParseTime(utils.ISO8601Seconds, BuildTime)
	// UnmarshalText reports no error for this format; malformed entries are skipped.
	_ = (&BuildInfo.Modules).UnmarshalText([]byte(BuildDeps))
	BuildInfoMap = BuildInfo.ToMap()
}
// ModuleBuildInfo describes one Go module dependency recorded at build time.
type ModuleBuildInfo struct {
	Path    string `json:"path"`
	Version string `json:"version"`
}

// ModuleBuildInfoMap maps a module's short name (the last element of its path) to its build info.
type ModuleBuildInfoMap map[string]ModuleBuildInfo

// UnmarshalText parses a comma-separated list of "module/path@version" entries, e.g.
// "github.com/cisco-open/go-lanai@main,github.com/foo/bar@v1.2.3", and replaces the receiver's content.
//
// Entries are keyed by path.Base of the module path, so modules that share the same last path
// element overwrite each other (the last one wins). Whitespace around the path and the version is
// trimmed. Entries without "@", or with an empty module path, are skipped. It never returns an error.
func (m *ModuleBuildInfoMap) UnmarshalText(text []byte) error {
	modules := ModuleBuildInfoMap{}
	for _, entry := range strings.Split(string(text), ",") {
		modPath, version, found := strings.Cut(entry, "@")
		modPath = strings.TrimSpace(modPath)
		if !found || modPath == "" {
			// without a module path, path.Base would return "." and produce a bogus entry
			continue
		}
		modules[path.Base(modPath)] = ModuleBuildInfo{
			Path:    modPath,
			Version: strings.TrimSpace(version),
		}
	}
	*m = modules
	return nil
}

// BuildInfoMetadata is the build information reported by the application.
type BuildInfoMetadata struct {
	Version   string             `json:"version"`
	BuildTime time.Time          `json:"build-time"`
	Hash      string             `json:"hash"`
	Modules   ModuleBuildInfoMap `json:"modules,omitempty"`
}

// ToMap converts the metadata to a generic map by round-tripping it through JSON, so the keys
// follow the struct's json tags. An empty map is returned if either conversion step fails.
func (m *BuildInfoMetadata) ToMap() map[string]interface{} {
	ret := map[string]interface{}{}
	data, err := json.Marshal(m)
	if err == nil {
		err = json.Unmarshal(data, &ret)
	}
	if err != nil {
		return map[string]interface{}{}
	}
	return ret
}
const (
	// propPrefix is the configuration prefix bound into buildInfoProperties.
	propPrefix = "info.app"
)

// buildInfoProperties holds the "info.app" properties that adjust the reported build info.
type buildInfoProperties struct {
	// Version, when non-empty, replaces the build version.
	Version string `json:"version"`
	// ShowDetails keeps the full version string (including everything after the first "-") when true.
	ShowDetails bool `json:"show-build-info"`
}

// defaultBuildInfoResolver is the default BuildInfoResolver: it reports the package-level BuildInfo
// adjusted by buildInfoProperties.
type defaultBuildInfoResolver struct {
	appCtx     *ApplicationContext
	properties buildInfoProperties
}
func newDefaultBuildInfoResolver(appCtx *ApplicationContext) *defaultBuildInfoResolver {
resolver := defaultBuildInfoResolver{
appCtx: appCtx,
}
_ = appCtx.Config().Bind(&resolver.properties, propPrefix)
return &resolver
}
func (r defaultBuildInfoResolver) Resolve() BuildInfoMetadata {
info := BuildInfo
if r.properties.Version != "" {
info.Version = r.properties.Version
}
/**
* DE9198: remove the build info from the version unless info.app.show-build-info=true
* @return
*/
if !r.properties.ShowDetails {
info.Version = strings.SplitN(info.Version, "-", 2)[0]
}
return info
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package bootstrap
import (
"context"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"go.uber.org/fx"
)
const (
	// FxCliRunnerGroup is the fx value group collecting CLI runners and their lifecycle hooks.
	FxCliRunnerGroup = "bootstrap_cli_runner"
	// CliRunnerModuleName is the name of the Module that executes CLI runners.
	CliRunnerModuleName = "CLI Runner"
)

// CliRunner is a unit of work executed when the application runs in CLI runner mode.
type CliRunner func(ctx context.Context) error

// WithOrder wraps the runner into an OrderedCliRunner with the given precedence.
func (r CliRunner) WithOrder(order int) OrderedCliRunner {
	var ordered OrderedCliRunner
	ordered.Precedence = order
	ordered.CliRunner = r
	return ordered
}

// OrderedCliRunner is a CliRunner with an explicit precedence; lower values run first.
type OrderedCliRunner struct {
	Precedence int
	CliRunner  CliRunner
}

// Order returns the runner's precedence.
func (r OrderedCliRunner) Order() int {
	return r.Precedence
}
// CliRunnerEnabler is implemented by types that can switch an application into CLI runner mode.
type CliRunnerEnabler interface {
	// EnableCliRunnerMode see bootstrap.EnableCliRunnerMode
	EnableCliRunnerMode(runnerProviders ...interface{})
}

// CliRunnerLifecycleHooks provide instrumentation around CliRunners.
// Hooks are collected from the "bootstrap_cli_runner" fx value group and sorted with order.OrderedFirstCompare.
// For each runner, every Before is called in turn and may return a derived context, which is then passed
// to the runner. Afterwards, every After is called with that context and the runner's error.
type CliRunnerLifecycleHooks interface {
	Before(ctx context.Context, runner CliRunner) context.Context
	After(ctx context.Context, runner CliRunner, err error) context.Context
}
// EnableCliRunnerMode should be called before Execute(), otherwise it won't run.
// "runnerProviders" are standard FX lifecycle functions that typically used with fx.Provide(...)
// signigure of "runnerProviders", but it should returns CliRunner or OrderedCliRunner, otherwise it won't run
//
// example of runner provider:
//
// func myRunner(di OtherDependencies) CliRunner {
// return func(ctx context.Context) error {
// // Do your stuff
// return err
// }
// }
//
// example of ordered runner provider:
//
// func myRunner(di OtherDependencies) OrderedCliRunner {
// return bootstrap.OrderedCliRunner{
// Precedence: 0,
// CliRunner: func(ctx context.Context) error {
// // Do your stuff
// return err
// },
// }
// }
//
// Using this pattern guarantees following things:
// 1. The application is automatically shutdown after all lifecycle hooks finished
// 2. The runner funcs are run after all other fx.Invoke
// 3. All other "OnStop" are executed regardless if any hook function returns error (graceful shutdown)
// 4. If any hook functions returns error, it reflected as non-zero process exit code
// 5. Each cli runner are separately traced if tracing is enabled
// 6. Any CliRunner without order is considered as having order 0
//
// Note: calling this function repeatedly would override previous invocation (i.e. only the last invocation takes effect)
func EnableCliRunnerMode(runnerProviders ...interface{}) {
enableCliRunnerMode(bootstrapper(), runnerProviders)
}
// newCliRunnerModule creates the Module that invokes cliRunnerExec at CommandLineRunnerPrecedence,
// which is the last of the framework precedences.
func newCliRunnerModule() *Module {
	m := &Module{
		Name:       CliRunnerModuleName,
		Precedence: CommandLineRunnerPrecedence,
	}
	m.Options = []fx.Option{fx.Invoke(cliRunnerExec)}
	return m
}
// enableCliRunnerMode finds the CLI runner Module registered on b, creating and registering one if
// absent. It then appends an fx.Provide that places every given provider into the FxCliRunnerGroup value group.
func enableCliRunnerMode(b *Bootstrapper, runnerProviders []interface{}) {
	var runnerModule *Module
	for m := range b.modules {
		if m != nil && m.Name == CliRunnerModuleName {
			runnerModule = m
			break
		}
	}
	if runnerModule == nil {
		runnerModule = newCliRunnerModule()
		b.Register(runnerModule)
	}

	annotated := make([]interface{}, 0, len(runnerProviders))
	for _, p := range runnerProviders {
		annotated = append(annotated, fx.Annotated{
			Group:  FxCliRunnerGroup,
			Target: p,
		})
	}
	runnerModule.Options = append(runnerModule.Options, fx.Provide(annotated...))
}
// cliDI collects everything contributed to the "bootstrap_cli_runner" fx value group (FxCliRunnerGroup):
// lifecycle hooks, plain runners and ordered runners.
type cliDI struct {
	fx.In
	Hooks          []CliRunnerLifecycleHooks `group:"bootstrap_cli_runner"`
	Runners        []CliRunner               `group:"bootstrap_cli_runner"`
	OrderedRunners []OrderedCliRunner        `group:"bootstrap_cli_runner"`
}
// cliRunnerExec registers the lifecycle hook that executes all CLI runners.
// In OnStart, runners run sequentially in ascending order, wrapped by the Before/After hooks, and
// stop at the first error. The application is then asked to shut down. The runner error, if any,
// is reported from OnStop so that start-up completes and the other stop hooks still run.
func cliRunnerExec(lc fx.Lifecycle, shutdowner fx.Shutdowner, di cliDI) {
	order.SortStable(di.Hooks, order.OrderedFirstCompare)

	// unordered runners default to order 0; ordered runners keep their own precedence
	runners := make([]OrderedCliRunner, 0, len(di.Runners)+len(di.OrderedRunners))
	for _, r := range di.Runners {
		runners = append(runners, r.WithOrder(0))
	}
	runners = append(runners, di.OrderedRunners...)
	order.SortStable(runners, order.OrderedFirstCompare)

	// first runner error, recorded in OnStart and surfaced in OnStop
	var runErr error
	lc.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			for _, r := range runners {
				runCtx := ctx
				for _, hook := range di.Hooks {
					runCtx = hook.Before(runCtx, r.CliRunner)
				}
				runErr = r.CliRunner(runCtx)
				for _, hook := range di.Hooks {
					runCtx = hook.After(runCtx, r.CliRunner, runErr)
				}
				if runErr != nil {
					break
				}
			}
			return shutdowner.Shutdown()
		},
		OnStop: func(context.Context) error {
			return runErr
		},
	})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package bootstrap
import (
"context"
"time"
)
const (
	// PropertyKeyApplicationName is the property used by ApplicationContext.Name.
	PropertyKeyApplicationName = "application.name"
)

// Unexported key types prevent collisions with context keys defined in other packages.
type startTimeCtxKey struct{}
type stopTimeCtxKey struct{}

var (
	// ctxKeyStartTime holds the time.Time (UTC) at which the ApplicationContext was created.
	ctxKeyStartTime = startTimeCtxKey{}
	// ctxKeyStopTime holds the time.Time (UTC) at which App.Run began the stop sequence.
	ctxKeyStopTime = stopTimeCtxKey{}
)
// ApplicationConfig gives access to application configuration properties.
type ApplicationConfig interface {
	// Value returns the property for key; callers such as ApplicationContext.Value treat nil as "not found".
	Value(key string) interface{}
	// Bind populates target from the properties under prefix.
	Bind(target interface{}, prefix string) error
}

// ApplicationContext is a context.Context that additionally carries the application configuration.
// String keys passed to Value are resolved against the configuration first; all other context
// calls are delegated to the embedded Context.
type ApplicationContext struct {
	context.Context
	config ApplicationConfig
}
// NewApplicationContext creates an ApplicationContext rooted at context.Background().
// opts are applied in order, and the current UTC time is then stored under the start-time key.
// The config is left unset here.
func NewApplicationContext(opts ...ContextOption) *ApplicationContext {
	parent := context.Background()
	for i := range opts {
		parent = opts[i](parent)
	}
	startedAt := time.Now().UTC()
	return &ApplicationContext{Context: context.WithValue(parent, ctxKeyStartTime, startedAt)}
}

// Config returns the ApplicationConfig attached to this context, which is nil until one is set.
func (c *ApplicationContext) Config() ApplicationConfig {
	return c.config
}

// Name returns the "application.name" value if it is a string, and "lanai" otherwise.
func (c *ApplicationContext) Name() string {
	if n, ok := c.Value(PropertyKeyApplicationName).(string); ok {
		return n
	}
	return "lanai"
}
/**************************
context.Context Interface
**************************/

// String implements fmt.Stringer.
func (*ApplicationContext) String() string {
	return "application context"
}

// Value resolves string keys against the application config first. It falls back to the embedded
// Context when no config is attached or the config yields nil. Non-string keys always go to the embedded Context.
func (c *ApplicationContext) Value(key interface{}) interface{} {
	if k, isString := key.(string); isString && c.config != nil {
		if v := c.config.Value(k); v != nil {
			return v
		}
	}
	return c.Context.Value(key)
}
/**********************
* unexported methods
***********************/

// withContext replaces the embedded Context in place and returns the same *ApplicationContext (no copy is made).
func (c *ApplicationContext) withContext(parent context.Context) *ApplicationContext {
	c.Context = parent
	return c
}

// withValue mutates c so that its embedded Context additionally carries k -> v.
func (c *ApplicationContext) withValue(k, v interface{}) *ApplicationContext {
	derived := context.WithValue(c.Context, k, v)
	return c.withContext(derived)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package bootstrap
import (
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx/fxevent"
"strings"
)
// fxPrinter implements fx.Printer (deprecated) and fxevent.Logger
// by writing fx messages and events to the log, tagged with the application context.
type fxPrinter struct {
	// logger is stored at construction.
	// NOTE(review): the methods of fxPrinter log through the package-level logger, not this field.
	logger log.Logger
	// appCtx is attached to every log entry.
	appCtx *ApplicationContext
}
// provideFxLogger is the fx.WithLogger constructor: it reports fx events through the "Bootstrap" logger.
func provideFxLogger(app *App) fxevent.Logger {
	printer := newFxPrinter(logger, app)
	return printer
}

// newFxPrinter creates an fxPrinter bound to the App's ApplicationContext.
func newFxPrinter(logger log.Logger, app *App) *fxPrinter {
	printer := &fxPrinter{logger: logger}
	printer.appCtx = app.ctx
	return printer
}

// logf logs msg at info level with the application context attached.
func (l *fxPrinter) logf(msg string, args ...interface{}) {
	contextual := logger.WithContext(l.appCtx)
	contextual.Infof(msg, args...)
}

// Printf implements the deprecated fx.Printer by logging at info level with the application context attached.
func (l *fxPrinter) Printf(s string, v ...interface{}) {
	contextual := logger.WithContext(l.appCtx)
	contextual.Infof(s, v...)
}
// LogEvent implements fxevent.Logger by translating fx events into log entries tagged with the
// application context.
//   - Failures are logged at warn level.
//   - Supplies, provides, stop signals, "RUNNING" and logger initialization are logged at info level.
//   - Invocations and OnStop executions are logged at debug level.
//   - OnStart executions and successful hook completions are deliberately not logged.
func (l *fxPrinter) LogEvent(event fxevent.Event) {
	lg := logger.WithContext(l.appCtx)
	switch e := event.(type) {
	case *fxevent.OnStartExecuting:
		// deliberately silent
	case *fxevent.OnStartExecuted:
		if e.Err != nil {
			lg.Warnf("HOOK OnStart\t\t%s called by %s failed in %s: %v", e.FunctionName, e.CallerName, e.Runtime, e.Err)
		}
	case *fxevent.OnStopExecuting:
		lg.Debugf("HOOK OnStop\t\t%s executing (caller: %s)", e.FunctionName, e.CallerName)
	case *fxevent.OnStopExecuted:
		if e.Err != nil {
			lg.Warnf("HOOK OnStop\t\t%s called by %s failed in %s: %v", e.FunctionName, e.CallerName, e.Runtime, e.Err)
		}
	case *fxevent.Supplied:
		if e.Err == nil {
			lg.Infof("SUPPLY\t%v", e.TypeName)
		} else {
			lg.Warnf("ERROR\tFailed to supply %v: %v", e.TypeName, e.Err)
		}
	case *fxevent.Provided:
		for _, typeName := range e.OutputTypeNames {
			lg.Infof("PROVIDE\t%v <= %v", typeName, e.ConstructorName)
		}
		if e.Err != nil {
			lg.Warnf("Error after options were applied: %v", e.Err)
		}
	case *fxevent.Invoking:
		lg.Debugf("INVOKE\t\t%s", e.FunctionName)
	case *fxevent.Invoked:
		if e.Err != nil {
			lg.Warnf("ERROR\t\tfx.Invoke(%v) called from:\n%+vFailed: %v", e.FunctionName, e.Trace, e.Err)
		}
	case *fxevent.Stopping:
		lg.Infof("%v", strings.ToUpper(e.Signal.String()))
	case *fxevent.Stopped:
		if e.Err != nil {
			lg.Warnf("ERROR\t\tFailed to stop cleanly: %v", e.Err)
		}
	case *fxevent.RollingBack:
		lg.Warnf("ERROR\t\tStart failed, rolling back: %v", e.StartErr)
	case *fxevent.RolledBack:
		if e.Err != nil {
			lg.Warnf("ERROR\t\tCouldn't roll back cleanly: %v", e.Err)
		}
	case *fxevent.Started:
		if e.Err == nil {
			lg.Infof("RUNNING")
		} else {
			lg.Warnf("ERROR\t\tFailed to start: %v", e.Err)
		}
	case *fxevent.LoggerInitialized:
		if e.Err == nil {
			lg.Infof("LOGGER\tInitialized custom logger from %v", e.ConstructorName)
		} else {
			lg.Warnf("ERROR\t\tFailed to initialize custom logger: %+v", e.Err)
		}
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package bootstrap
import (
"go.uber.org/fx"
)
// Module precedences. Bootstrapper.NewApp sorts modules by ascending Precedence, so lower values come first.
const (
	// LowestPrecedence is the largest int; modules with this value are sorted last.
	LowestPrecedence = int(^uint(0) >> 1) // max int
	// HighestPrecedence is the smallest int; modules with this value are sorted first.
	HighestPrecedence = -LowestPrecedence - 1 // min int
	// FrameworkModulePrecedenceBandwidth + 1 (= 800) is the width of the band reserved for each framework module.
	FrameworkModulePrecedenceBandwidth = 799
	// FrameworkModulePrecedence is the base of the framework bands, leaving room for 200 bands below LowestPrecedence.
	FrameworkModulePrecedence = LowestPrecedence - 200*(FrameworkModulePrecedenceBandwidth+1)
	// AnonymousModulePrecedence is one less than FrameworkModulePrecedence, so it sorts before every framework band.
	AnonymousModulePrecedence = FrameworkModulePrecedence - 1
	// PriorityModulePrecedence sorts immediately after HighestPrecedence.
	PriorityModulePrecedence = HighestPrecedence + 1
)

// Framework module precedences. The blank identifier consumes iota 0, so AppConfigPrecedence is
// FrameworkModulePrecedence + 800, TracingPrecedence is FrameworkModulePrecedence + 1600, and so on.
// Each constant below is one band (FrameworkModulePrecedenceBandwidth + 1) after the previous one.
const (
	_ = FrameworkModulePrecedence + iota*(FrameworkModulePrecedenceBandwidth+1)
	AppConfigPrecedence
	TracingPrecedence
	ActuatorPrecedence
	ConsulPrecedence
	VaultPrecedence
	AwsPrecedence
	TlsConfigPrecedence
	RedisPrecedence
	DatabasePrecedence
	KafkaPrecedence
	OpenSearchPrecedence
	WebPrecedence
	SecurityPrecedence
	DebugPrecedence
	ServiceDiscoveryPrecedence
	DistributedLockPrecedence
	TenantHierarchyAccessorPrecedence
	TenantHierarchyLoaderPrecedence
	TenantHierarchyModifierPrecedence
	HttpClientPrecedence
	SecurityIntegrationPrecedence
	SwaggerPrecedence
	StartupSummaryPrecedence
	// CommandLineRunnerPrecedence invocation should happen after everything else, in case it needs functionality from any other modules
	CommandLineRunnerPrecedence
)
// Module is a named bundle of fx options together with its bootstrap precedence.
type Module struct {
	Name string
	// Precedence governs the order of options (and therefore invokers) between different Modules.
	// Modules are sorted by ascending Precedence, so lower values come first.
	Precedence int
	// PriorityOptions are fx.Options applied before any regular Options
	PriorityOptions []fx.Option
	// Options is a collection fx.Option: fx.Provide, fx.Invoke, etc.
	Options []fx.Option
	// Modules is a collection of *Module that will also be initialized.
	// They are not necessarily sub-modules. During bootstrapping, all modules are flattened and Precedence are calculated at the end
	Modules []*Module
}
// newAnonymousModule creates the module that holds options registered via AddOptions().
// Its Precedence (AnonymousModulePrecedence) is one less than FrameworkModulePrecedence. Because
// modules are sorted by ascending Precedence, it is placed ahead of every framework module.
func newAnonymousModule() *Module {
	m := &Module{Name: "anonymous"}
	m.Precedence = AnonymousModulePrecedence
	return m
}

// newApplicationMainModule creates the application's "main" module, which holds the options passed in via NewAppCmd().
// It uses PriorityModulePrecedence (HighestPrecedence + 1), so only HighestPrecedence modules sort before it.
func newApplicationMainModule() *Module {
	m := &Module{Name: "main"}
	m.Precedence = PriorityModulePrecedence
	return m
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package bootstrap
import (
"context"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
"time"
)
var logger = log.New("Bootstrap")

// InitModule returns the module holding the core bootstrap wiring, at HighestPrecedence.
// Its priority options install the fx event logger, supply the *CliExecContext and *App,
// provide the ApplicationContext and BuildInfoResolver, and invoke bootstrap.
func InitModule(cliCtx *CliExecContext, app *App) *Module {
	priority := []fx.Option{
		fx.WithLogger(provideFxLogger),
		fx.Supply(cliCtx),
		fx.Supply(app),
		fx.Provide(provideApplicationContext),
		fx.Provide(provideBuildInfoResolver),
		fx.Invoke(bootstrap),
	}
	return &Module{
		Precedence:      HighestPrecedence,
		PriorityOptions: priority,
	}
}

// MiscModules returns auxiliary modules at different precedences:
//   - A module at StartupSummaryPrecedence invokes startupTiming as a regular option, so its OnStart
//     hook is registered late and reports start-up time after the others.
//   - A HighestPrecedence module invokes shutdownTiming as a priority option, so its OnStop hook is
//     registered early. Since fx runs OnStop hooks in reverse order, that hook runs last.
func MiscModules() []*Module {
	startup := &Module{
		Precedence: StartupSummaryPrecedence,
		Options:    []fx.Option{fx.Invoke(startupTiming)},
	}
	shutdown := &Module{
		Precedence:      HighestPrecedence,
		PriorityOptions: []fx.Option{fx.Invoke(shutdownTiming)},
	}
	return []*Module{startup, shutdown}
}
// noopAppConfig is the fallback ApplicationConfig: it holds no values and binding does nothing.
type noopAppConfig struct{}

// Value always returns nil.
func (noopAppConfig) Value(string) interface{} { return nil }

// Bind leaves the target untouched and reports success.
func (noopAppConfig) Bind(interface{}, string) error { return nil }
// appCtxDI holds the dependencies of provideApplicationContext.
// AppConfig is optional; when it is missing, a no-op config is used instead.
type appCtxDI struct {
	fx.In
	App       *App
	AppConfig ApplicationConfig `optional:"true"`
}
// provideApplicationContext attaches the ApplicationConfig to the App's ApplicationContext and provides it.
// If no ApplicationConfig is available, a warning is logged and noopAppConfig is attached instead.
func provideApplicationContext(di appCtxDI) *ApplicationContext {
	appCtx := di.App.ctx
	if appCtx.config = di.AppConfig; appCtx.config != nil {
		return appCtx
	}
	logger.WithContext(appCtx).Warnf(`bootstrap.ApplicationConfig is not available`)
	appCtx.config = noopAppConfig{}
	return appCtx
}

// provideBuildInfoResolver provides the default BuildInfoResolver.
func provideBuildInfoResolver(appCtx *ApplicationContext) BuildInfoResolver {
	resolver := newDefaultBuildInfoResolver(appCtx)
	return resolver
}
func bootstrap(lc fx.Lifecycle, ac *ApplicationContext) {
logProperties := &log.Properties{}
err := ac.config.Bind(logProperties, "log")
if err == nil {
err = log.UpdateLoggingConfiguration(logProperties)
}
if err != nil {
logger.Error("Error updating logging configuration", "error", err)
}
lc.Append(fx.Hook{
OnStart: func(_ context.Context) error {
logger.WithContext(ac).Info("On Application Start") //nolint:contextcheck
return nil
},
})
}
// startupTiming registers an OnStart hook that logs how long the application took to start,
// measured from the start time stored in the context. Nothing is logged if that value is absent.
func startupTiming(lc fx.Lifecycle, appCtx *ApplicationContext) {
	lc.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			startedAt, ok := ctx.Value(ctxKeyStartTime).(time.Time)
			if !ok {
				return nil
			}
			elapsed := time.Since(startedAt).Truncate(time.Millisecond)
			logger.WithContext(ctx).Infof("Started %s in %v", appCtx.Name(), elapsed)
			return nil
		},
	})
}

// shutdownTiming registers an OnStop hook that logs how long the application took to stop,
// measured from the stop time stored in the context. Nothing is logged if that value is absent.
func shutdownTiming(lc fx.Lifecycle, appCtx *ApplicationContext) {
	lc.Append(fx.Hook{
		OnStop: func(ctx context.Context) error {
			stoppedAt, ok := ctx.Value(ctxKeyStopTime).(time.Time)
			if !ok {
				return nil
			}
			elapsed := time.Since(stoppedAt).Truncate(time.Millisecond)
			logger.WithContext(ctx).Infof("Stopped %s in %v", appCtx.Name(), elapsed)
			return nil
		},
	})
}
/*
Copyright © 2020 NAME HERE <EMAIL ADDRESS>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package bootstrap
import (
"context"
"fmt"
"github.com/spf13/cobra"
"go.uber.org/fx"
"os"
"regexp"
"strconv"
)
const (
	// CLI flag names registered on rootCmd in init().
	CliFlagActiveProfile     = "active-profiles"
	CliFlagAdditionalProfile = "additional-profiles"
	CliFlagConfigSearchPath  = "config-search-path"
	CliFlagDebug             = "debug"
	// EnvVarDebug is the environment variable read by parseEnvVars to preset the debug flag.
	EnvVarDebug = "DEBUG"
)

var (
	// argsPattern validates positional CLI arguments, which are expected as "property-path=value".
	// NOTE(review): the pattern is unanchored and requires two or more characters before "=". MatchString
	// therefore accepts any argument that merely contains such a substring (e.g. "!!ab=c"), while it
	// rejects single-character keys like "a=1". Confirm whether `^[a-zA-Z][a-zA-Z0-9\-._]*=.*$` was intended.
	argsPattern = regexp.MustCompile(`[a-zA-Z][a-zA-Z0-9\-._]+=.*`)
	// rootCmd represents the base command when called without any subcommands
	// Note: when running app as `./app --flag1 value1 --flag2 value2 -- any-thing...`
	// the values after bare `--` are passed in as args. we could use it as CLI properties assignment
	rootCmd = &cobra.Command{
		Short: "A go-lanai based application.",
		Long:  "This is a go-lanai based application.",
		// unknown flags are tolerated instead of failing the parse
		FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
		// every positional argument must match argsPattern
		Args: func(cmd *cobra.Command, args []string) error {
			for _, arg := range args {
				if !argsPattern.MatchString(arg) {
					return fmt.Errorf(`CLI properties should be in format of "property-path=value", but got "%s"`, arg)
				}
			}
			return nil
		},
	}
	// cliCtx is the package-level CLI state, populated by the flags registered in init() and by
	// parseEnvVars. NewAppCmd copies it for each run.
	cliCtx = CliExecContext{}
)
// CliExecContext carries the command-line state of one application run.
type CliExecContext struct {
	// Cmd is the cobra command being executed.
	Cmd *cobra.Command
	// ActiveProfiles from --active-profiles / -P.
	ActiveProfiles []string
	// AdditionalProfiles from --additional-profiles / -p.
	AdditionalProfiles []string
	// ConfigSearchPaths from --config-search-path / -c.
	ConfigSearchPaths []string
	// Debug from --debug or the DEBUG environment variable.
	Debug bool
	// Args are the positional arguments, validated by rootCmd as "property-path=value" pairs.
	Args []string
}
// init registers the built-in persistent flags on rootCmd and then reads environment variables.
// Environment variables are parsed last. Registering a flag presumably resets the bound variable
// to the flag's default, which would otherwise overwrite the DEBUG-derived value.
func init() {
	flags := rootCmd.PersistentFlags()
	flags.StringSliceVarP(&cliCtx.ActiveProfiles, CliFlagActiveProfile, "P", []string{},
		`Comma separated active profiles. Override property "application.profiles.active"`)
	// small letter p instead of capital P
	flags.StringSliceVarP(&cliCtx.AdditionalProfiles, CliFlagAdditionalProfile, "p", []string{},
		`Comma separated additional profiles. Set property "application.profiles.additional". Additional profiles is added to active profiles`)
	flags.StringSliceVarP(&cliCtx.ConfigSearchPaths, CliFlagConfigSearchPath, "c", []string{},
		`Comma separated paths. Override property "config.file.search-path"`)
	flags.BoolVar(&cliCtx.Debug, CliFlagDebug, false,
		fmt.Sprintf(`Boolean that toggles debug features. Override EnvVar "%s"`, EnvVarDebug))

	parseEnvVars()
}
// parseEnvVars presets cliCtx.Debug from the DEBUG environment variable when it is set.
// A valid boolean value is used as-is; any other value (including an empty one) enables debug.
func parseEnvVars() {
	raw, present := os.LookupEnv(EnvVarDebug)
	if !present {
		return
	}
	enabled, err := strconv.ParseBool(raw)
	cliCtx.Debug = err != nil || enabled
}

// DebugEnabled reports whether debug features are on. It returns false unless debug is enabled
// through the DEBUG environment variable or the --debug flag.
func DebugEnabled() bool {
	return cliCtx.Debug
}
// AddStringFlag registers a persistent string flag on the root command.
// It should be called before Execute() so the flag is recognized.
func AddStringFlag(flagVar *string, name string, defaultValue string, usage string) {
	flags := rootCmd.PersistentFlags()
	flags.StringVar(flagVar, name, defaultValue, usage)
}

// AddBoolFlag registers a persistent boolean flag on the root command.
// It should be called before Execute() so the flag is recognized.
func AddBoolFlag(flagVar *bool, name string, defaultValue bool, usage string) {
	flags := rootCmd.PersistentFlags()
	flags.BoolVar(flagVar, name, defaultValue, usage)
}
// Execute runs the globally configured application, i.e. the Modules and fx.Options added via
// package-level functions such as Register or AddOptions.
// It is called by main.main() and only needs to happen once. An execution error is logged and
// the process exits with code 1.
func Execute() {
	err := rootCmd.Execute()
	if err == nil {
		return
	}
	logger.Errorf("%v", err)
	os.Exit(1)
}

// ExecuteContainedApp is similar to Execute, but runs with a separately configured Bootstrapper.
// In this mode, the bootstrapping process ignores any globally configured modules and options.
// It is usually called by the test framework; service developers normally don't use it directly.
func ExecuteContainedApp(ctx context.Context, b *Bootstrapper) {
	err := rootCmd.ExecuteContext(contextWithBootstrapper(ctx, b))
	if err == nil {
		return
	}
	logger.Errorf("%v", err)
	os.Exit(1)
}
// CliOptions customizes the root cobra command, e.g. by adding sub-commands.
type CliOptions func(cmd *cobra.Command)

// NewAppCmd prepares rootCmd to run the application named appName.
// cliOptions are applied to rootCmd immediately. When the command runs, it builds the App from the
// Bootstrapper found in the command context (set by ExecuteContainedApp) or, if there is none, from
// the global Bootstrapper. priorityOptions and regularOptions become the main module's options,
// and the App is then run.
func NewAppCmd(appName string, priorityOptions []fx.Option, regularOptions []fx.Option, cliOptions ...CliOptions) {
	rootCmd.Use = appName
	// To add more commands, declare them like rootCmd and attach them to rootCmd through cliOptions.
	for _, customize := range cliOptions {
		customize(rootCmd)
	}

	rootCmd.Run = func(cmd *cobra.Command, args []string) {
		// per-run copy of the shared CLI context
		execCtx := cliCtx
		execCtx.Cmd, execCtx.Args = cmd, args

		b, _ := cmd.Context().Value(ctxKeyBootstrapper).(*Bootstrapper)
		if b == nil {
			b = bootstrapper()
		}
		b.NewApp(&execCtx, priorityOptions, regularOptions).Run()
	}
}
/********************
Cmd Context
********************/

// bootstrapperCtxKey is the private key under which a *Bootstrapper is exposed by bootstrapContext.
type bootstrapperCtxKey struct{}

var ctxKeyBootstrapper = bootstrapperCtxKey{}

// bootstrapContext wraps a parent context and answers ctxKeyBootstrapper lookups with its Bootstrapper.
type bootstrapContext struct {
	context.Context
	b *Bootstrapper
}

// contextWithBootstrapper returns a child of parent that carries b under ctxKeyBootstrapper.
func contextWithBootstrapper(parent context.Context, b *Bootstrapper) context.Context {
	ctx := &bootstrapContext{Context: parent}
	ctx.b = b
	return ctx
}

// Value returns the Bootstrapper for ctxKeyBootstrapper and delegates every other key to the parent.
func (c *bootstrapContext) Value(key interface{}) interface{} {
	if key == ctxKeyBootstrapper {
		return c.b
	}
	return c.Context.Value(key)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
// Package certsinit initializes the certificate manager
// with various certificate sources.
package certsinit
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/certs"
filecerts "github.com/cisco-open/go-lanai/pkg/certs/source/file"
"go.uber.org/fx"
"io"
)
// PropertiesPrefix is the configuration prefix that certs.Properties are bound from.
const PropertiesPrefix = `certificates`

// Module provides the bound certs.Properties, the default certs.Manager/certs.Registrar and the
// file-based certificate source factory. It also registers a stop hook that closes the manager.
var Module = &bootstrap.Module{
	Name:       "certs",
	Precedence: bootstrap.TlsConfigPrecedence,
	Options: []fx.Option{
		fx.Provide(BindProperties, ProvideDefaultManager),
		fx.Provide(
			filecerts.FxProvider(),
		),
		fx.Invoke(RegisterManagerLifecycle),
	},
}

// Use registers Module with the bootstrap framework.
func Use() {
	bootstrap.Register(Module)
}
// mgrDI declares the dependencies of ProvideDefaultManager.
type mgrDI struct {
	fx.In
	AppCfg    bootstrap.ApplicationConfig
	Props     certs.Properties
	Factories []certs.SourceFactory `group:"certs"`
}

// ProvideDefaultManager builds a certs.DefaultManager. The manager loads config-path options through
// the application config, uses the bound certs.Properties, and has every non-nil source factory of
// the "certs" group registered. The same instance is returned as both Manager and Registrar.
// It panics if a factory cannot be registered.
func ProvideDefaultManager(di mgrDI) (certs.Manager, certs.Registrar) {
	manager := certs.NewDefaultManager(func(m *certs.DefaultManager) {
		m.Properties = di.Props
		m.ConfigLoaderFunc = di.AppCfg.Bind
	})
	for i := range di.Factories {
		// factory providers may contribute nil when their source type is unsupported
		if factory := di.Factories[i]; factory != nil {
			manager.MustRegister(factory)
		}
	}
	return manager, manager
}
// BindProperties creates certs.Properties and binds them from the application configuration under
// PropertiesPrefix. It panics if binding fails.
func BindProperties(appCfg bootstrap.ApplicationConfig) certs.Properties {
	props := certs.NewProperties()
	err := appCfg.Bind(props, PropertiesPrefix)
	if err == nil {
		return *props
	}
	panic(fmt.Errorf("failed to bind certificate properties: %v", err))
}
// RegisterManagerLifecycle appends an fx stop hook that closes m when m implements io.Closer.
// Otherwise the hook does nothing.
func RegisterManagerLifecycle(lc fx.Lifecycle, m certs.Manager) {
	closer, closable := m.(io.Closer)
	lc.Append(fx.StopHook(func(context.Context) error {
		if !closable {
			return nil
		}
		return closer.Close()
	}))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package certs
import (
"context"
"encoding/json"
"fmt"
"io"
"sync"
)
// DefaultManager creates certificate sources through registered SourceFactory instances and keeps
// track of the sources it created so that Close can release them.
type DefaultManager struct {
	// Mutex is locked by Source while it uses factories and activeSources.
	sync.Mutex
	// Properties supplies the named presets used by the preset option.
	Properties Properties
	// ConfigLoaderFunc binds the configuration found at configPath into target (used by the config-path option).
	ConfigLoaderFunc func(target interface{}, configPath string) error
	// factories maps each source type to the factory registered for it.
	factories map[SourceType]SourceFactory
	// activeSources records the sources created by Source, grouped by type.
	activeSources map[SourceType][]Source
}
// NewDefaultManager creates a DefaultManager with empty Properties and no registered factories.
// It then applies opts in order to customize it.
func NewDefaultManager(opts ...func(mgr *DefaultManager)) *DefaultManager {
	mgr := &DefaultManager{
		Properties:    *NewProperties(),
		factories:     map[SourceType]SourceFactory{},
		activeSources: map[SourceType][]Source{},
	}
	for i := range opts {
		opts[i](mgr)
	}
	return mgr
}
// Register registers each item in order and stops at the first failure, returning its error.
// Items registered before the failure remain registered. Only SourceFactory items are supported.
func (m *DefaultManager) Register(items ...interface{}) error {
	var err error
	for i := 0; i < len(items) && err == nil; i++ {
		err = m.register(items[i])
	}
	return err
}

// MustRegister is like Register but panics on error.
func (m *DefaultManager) MustRegister(items ...interface{}) {
	err := m.Register(items...)
	if err == nil {
		return
	}
	panic(err)
}
// Source resolves the source configuration from opts (a preset, a config path or a raw config).
// It then creates and initializes a source with the factory registered for the resolved type.
// The new source is recorded so that Close can release it later.
func (m *DefaultManager) Source(ctx context.Context, opts ...Options) (Source, error) {
	var opt Option
	for _, apply := range opts {
		apply(&opt)
	}
	cfg, err := m.resolveSourceConfig(&opt)
	if err != nil {
		return nil, err
	}

	m.Lock()
	defer m.Unlock()
	factory, found := m.factories[cfg.Type]
	if !found {
		return nil, fmt.Errorf("unsupported TLS source: %s", cfg.Type)
	}
	src, err := factory.LoadAndInit(ctx, func(sc *SourceConfig) {
		sc.RawConfig = cfg.RawConfig
	})
	if err != nil {
		return nil, err
	}
	m.activeSources[cfg.Type] = append(m.activeSources[cfg.Type], src)
	return src, nil
}
// Close closes every source created by Source that implements io.Closer. Closing is best-effort:
// errors returned by individual sources are ignored and Close always returns nil.
// The manager lock is held so that Close does not race with concurrent Source calls. Tracked
// sources are then forgotten, so a repeated Close does not close them again.
func (m *DefaultManager) Close() error {
	m.Lock()
	defer m.Unlock()
	for _, sources := range m.activeSources {
		for _, src := range sources {
			if closer, ok := src.(io.Closer); ok {
				// best-effort: one failing source must not prevent the others from being closed
				_ = closer.Close()
			}
		}
	}
	m.activeSources = make(map[SourceType][]Source)
	return nil
}
// register adds a single item. A SourceFactory is keyed by its Type() and replaces any factory
// previously registered for that type. Any other item is rejected with an error.
func (m *DefaultManager) register(item interface{}) error {
	factory, ok := item.(SourceFactory)
	if !ok {
		return fmt.Errorf("unable to register unsupported item: %T", item)
	}
	m.factories[factory.Type()] = factory
	return nil
}
// resolveSourceConfig turns exactly one of opt.Preset, opt.ConfigPath or opt.RawConfig into a
// sourceConfig whose RawConfig holds the source's full JSON configuration:
//   - Preset: looked up in m.Properties.Presets.
//   - ConfigPath: loaded through m.ConfigLoaderFunc.
//   - RawConfig: json.RawMessage, []byte and string values are used as JSON text; any other value
//     is marshaled to JSON first. A non-empty opt.Type overrides the decoded "type" field.
//
// It returns an error when none or more than one of these options is set, or when the selected
// configuration cannot be resolved or parsed.
func (m *DefaultManager) resolveSourceConfig(opt *Option) (*sourceConfig, error) {
	var src sourceConfig
	switch {
	case len(opt.Preset) != 0 && len(opt.ConfigPath) == 0 && opt.RawConfig == nil:
		preset, ok := m.Properties.Presets[opt.Preset]
		if !ok {
			return nil, fmt.Errorf(`invalid certificate options: preset [%s] is not found`, opt.Preset)
		}
		if e := json.Unmarshal(preset, &src); e != nil {
			return nil, fmt.Errorf(`unable to resolve certificate source preset [%s]: %v`, opt.Preset, e)
		}
	case len(opt.Preset) == 0 && len(opt.ConfigPath) != 0 && opt.RawConfig == nil:
		// ConfigLoaderFunc is not set by NewDefaultManager; report an error instead of panicking on a nil func
		if m.ConfigLoaderFunc == nil {
			return nil, fmt.Errorf(`unable to resolve certificate source configuration: no config loader available for path [%s]`, opt.ConfigPath)
		}
		if e := m.ConfigLoaderFunc(&src, opt.ConfigPath); e != nil {
			return nil, fmt.Errorf(`unable to resolve certificate source configuration: %v`, e)
		}
	case len(opt.Preset) == 0 && len(opt.ConfigPath) == 0 && opt.RawConfig != nil:
		var rawJson []byte
		switch v := opt.RawConfig.(type) {
		case json.RawMessage:
			rawJson = v
		case []byte:
			rawJson = v
		case string:
			rawJson = []byte(v)
		default:
			var e error
			if rawJson, e = json.Marshal(opt.RawConfig); e != nil {
				return nil, fmt.Errorf(`invalid certificate options, unsupported RawConfig type [%T]: %v`, opt.RawConfig, e)
			}
		}
		if e := json.Unmarshal(rawJson, &src); e != nil {
			return nil, fmt.Errorf(`invalid certificate options, cannot parse "raw config" as a valid JSON block: %v`, e)
		}
		if len(opt.Type) != 0 {
			src.Type = opt.Type
		}
	default:
		return nil, fmt.Errorf(`invalid certificate options, one of "preset", "config path" or "raw config" is required. Got %v`, opt)
	}
	return &src, nil
}
/*************************
Helpers
*************************/
// sourceConfig is a resolved certificate source configuration. Type is decoded from the JSON
// "type" field. RawConfig keeps the entire JSON document so it can be handed to the source factory.
type sourceConfig struct {
	Type      SourceType      `json:"type"`
	RawConfig json.RawMessage `json:"-"`
}

// UnmarshalJSON decodes the "type" field and keeps a copy of the complete input as RawConfig.
func (c *sourceConfig) UnmarshalJSON(data []byte) error {
	// json.Unmarshaler implementations must copy data if they retain it after returning;
	// the decoder may pass a slice of a buffer that is later reused or mutated.
	c.RawConfig = append(json.RawMessage(nil), data...)
	type cfg sourceConfig
	return json.Unmarshal(data, (*cfg)(c))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package certs
// WithSourceProperties configures the source from SourceProperties, typically bound from the
// application configuration.
// When Preset is set, that preset is used. Otherwise the raw JSON the properties were bound from
// (props.Raw) is used as raw config, and a non-empty props.Type overrides its "type" field.
// A nil props, or properties with neither a preset nor raw JSON, leave the options untouched so that
// the manager reports the missing configuration instead of a JSON parsing error.
func WithSourceProperties(props *SourceProperties) Options {
	return func(opt *Option) {
		switch {
		case props == nil:
			return
		case len(props.Preset) != 0:
			opt.Preset = props.Preset
		case len(props.Raw) != 0:
			// Assign only non-empty Raw: a nil json.RawMessage stored in the interface-typed
			// RawConfig would be a non-nil interface and be treated as an (unparsable) raw config.
			opt.RawConfig = props.Raw
			if len(props.Type) != 0 {
				opt.Type = props.Type
			}
		}
	}
}
// WithPreset selects the named preset from Properties.Presets.
// Use only one of WithPreset, WithConfigPath or WithRawConfig/WithType.
func WithPreset(presetName string) Options {
	return func(opt *Option) { opt.Preset = presetName }
}

// WithConfigPath loads the source configuration from the given application config path.
func WithConfigPath(configPath string) Options {
	return func(opt *Option) { opt.ConfigPath = configPath }
}

// WithRawConfig uses rawCfg as the source configuration. JSON text may be given as
// json.RawMessage, []byte or string; any other value is marshaled to JSON.
func WithRawConfig(rawCfg interface{}) Options {
	return func(opt *Option) { opt.RawConfig = rawCfg }
}

// WithType is WithRawConfig(cfg) that also forces the source type to srcType,
// regardless of any "type" value inside cfg.
func WithType(srcType SourceType, cfg interface{}) Options {
	return func(opt *Option) {
		opt.RawConfig = cfg
		opt.Type = srcType
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package certs
import (
"encoding/json"
)
// Properties holds the certificate configuration of the application.
type Properties struct {
	// Sources holds per-type default configuration, keyed by source type, as raw JSON.
	Sources map[SourceType]json.RawMessage `json:"sources"`
	// Presets holds named source configurations as raw JSON; each is expected to contain "type".
	Presets map[string]json.RawMessage `json:"presets"`
}
// SourceProperties are convenient properties for other packages to bind.
type SourceProperties struct {
	// Preset is optional. When set, it should match a key in Properties.Presets.
	Preset string `json:"preset"`
	// Type is required when Preset is not set. It is optional and ignored when Preset is set.
	Type SourceType `json:"type"`
	// Raw stores the configuration as JSON.
	// When Preset is set, Raw might be empty. Otherwise, Raw should at least have "type".
	Raw json.RawMessage `json:"-"`
}

// UnmarshalJSON decodes Preset and Type and keeps a copy of the complete input JSON in Raw.
func (p *SourceProperties) UnmarshalJSON(data []byte) error {
	// json.Unmarshaler implementations must copy data if they retain it after returning.
	p.Raw = append(json.RawMessage(nil), data...)
	type props SourceProperties
	return json.Unmarshal(data, (*props)(p))
}
// NewProperties returns Properties whose Sources and Presets maps are empty but non-nil.
func NewProperties() *Properties {
	props := Properties{
		Sources: make(map[SourceType]json.RawMessage),
		Presets: make(map[string]json.RawMessage),
	}
	return &props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package acmcerts
import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"github.com/aws/aws-sdk-go-v2/service/acm"
"github.com/cisco-open/go-lanai/pkg/certs"
certsource "github.com/cisco-open/go-lanai/pkg/certs/source"
"github.com/cisco-open/go-lanai/pkg/utils/loop"
"go.step.sm/crypto/pemutil"
"regexp"
"strings"
"sync"
"time"
)
// AcmProvider is a certificate source backed by a certificate stored in AWS Certificate Manager (ACM).
type AcmProvider struct {
	props     SourceProperties
	acmClient *acm.Client
	// cache is nil when the local file cache could not be created.
	cache *certsource.FileCache
	// cachedCertificate is the current client certificate; guarded by mutex.
	cachedCertificate *tls.Certificate
	// lcCtx is the context the renewal loop is started with.
	lcCtx context.Context
	mutex sync.RWMutex
	// once guards the one-time initialization performed by LazyInit.
	once sync.Once
	// monitor runs the periodic certificate renewal task.
	monitor *loop.Loop
	// monitorCancel is the cancel function returned when LazyInit started monitor.
	monitorCancel context.CancelFunc
}
// NewAcmProvider creates a certificate source for the ACM certificate identified by p.ARN.
// A local file cache is created under p.CachePath. If that fails, a warning is logged and the
// source works without it, but Files will return an error. ctx is kept as the context that the
// renewal loop is started with.
func NewAcmProvider(ctx context.Context, acm *acm.Client, p SourceProperties) certs.Source {
	cache, err := certsource.NewFileCache(func(opt *certsource.FileCacheOption) {
		opt.Root = p.CachePath
		opt.Type = sourceType
		opt.Prefix = resolveCacheKey(&p)
	})
	if err != nil {
		logger.WithContext(ctx).Warnf("file cache for %s certificate source is not enabled: %v", sourceType, err)
	}

	provider := &AcmProvider{
		props:     p,
		acmClient: acm,
		lcCtx:     ctx,
		monitor:   loop.NewLoop(),
	}
	provider.cache = cache
	return provider
}
// Close stops the certificate renewal loop started by LazyInit, if any. It is safe to call more
// than once and always returns nil.
// NOTE(review): a renewal task that is still waiting on the initial time.AfterFunc delay is handed
// to the cancelled loop; presumably the loop discards it — confirm against pkg/utils/loop.
func (a *AcmProvider) Close() error {
	a.mutex.Lock()
	cancel := a.monitorCancel
	a.monitorCancel = nil
	a.mutex.Unlock()
	// call outside the lock in case cancellation waits for a running renewal that needs the mutex
	if cancel != nil {
		cancel()
	}
	return nil
}
// TLSConfig lazily initializes the source and returns a client tls.Config. RootCAs are freshly
// exported from ACM on each call. The client certificate comes from the in-memory certificate kept
// up to date by the renewal loop. MinVersion is taken from the MinTLSVersion property.
func (a *AcmProvider) TLSConfig(ctx context.Context, _ ...certs.TLSOptions) (*tls.Config, error) {
	if err := a.LazyInit(ctx); err != nil {
		return nil, err
	}
	pool, err := a.RootCAs(ctx)
	if err != nil {
		return nil, err
	}
	minVersion, err := certsource.ParseTLSVersion(a.props.MinTLSVersion)
	if err != nil {
		return nil, err
	}
	//nolint:gosec // false positive - G402: TLS MinVersion too low
	return &tls.Config{
		GetClientCertificate: a.toGetClientCertificateFunc(),
		RootCAs:              pool,
		MinVersion:           minVersion,
	}, nil
}
// Files lazily initializes the source and returns the paths of the cached CA, certificate and
// private key files. It fails when the file cache is not enabled for this source.
func (a *AcmProvider) Files(ctx context.Context) (*certs.CertificateFiles, error) {
	if err := a.LazyInit(ctx); err != nil {
		return nil, err
	}
	cache := a.cache
	if cache == nil {
		return nil, fmt.Errorf("unable to access certificates as local files: file cache is not enabled for source [%s]", sourceType)
	}
	files := &certs.CertificateFiles{
		RootCAPaths:     []string{cache.ResolvePath(certsource.CachedFileKeyCA)},
		CertificatePath: cache.ResolvePath(certsource.CachedFileKeyCertificate),
		PrivateKeyPath:  cache.ResolvePath(certsource.CachedFileKeyPrivateKey),
	}
	return files, nil
}
// RootCAs exports the certificate identified by the ARN property from ACM and returns a pool built
// from its certificate chain. When the file cache is enabled, the chain is also written to it. If
// that write fails, the pool is still returned, together with the caching error.
func (a *AcmProvider) RootCAs(ctx context.Context) (*x509.CertPool, error) {
	input := &acm.ExportCertificateInput{
		CertificateArn: &a.props.ARN,
		Passphrase:     []byte(a.props.Passphrase),
	}
	output, err := a.acmClient.ExportCertificate(ctx, input)
	if err != nil {
		logger.Errorf("Could not fetch ACM certificate %s: %s", a.props.ARN, err.Error())
		return nil, err
	}
	// the SDK models CertificateChain as an optional pointer; avoid a nil dereference panic
	if output.CertificateChain == nil {
		return nil, fmt.Errorf("exported ACM certificate %s has no certificate chain", a.props.ARN)
	}
	// Clean the returned CA (deal with bug in localStack)
	cleaned := strings.ReplaceAll(*output.CertificateChain, " -----END CERTIFICATE-----", "-----END CERTIFICATE-----")
	pemBytes := []byte(cleaned)
	certPool := x509.NewCertPool()
	certPool.AppendCertsFromPEM(pemBytes)
	if a.cache != nil {
		if err := a.cache.CachePEM(pemBytes, certsource.CachedFileKeyCA); err != nil {
			logger.WithContext(ctx).Warnf(`unable to cache CA: %v`, err)
			return certPool, err
		}
	}
	return certPool, nil
}
// LazyInit performs the one-time initialization of this source. It exports the CA chain and the
// client certificate from ACM, then starts the monitor loop and schedules periodic renewal of the
// client certificate. Only the first call does this work (sync.Once), and that call returns any
// error it encountered. Because sync.Once never retries, later calls return an error whenever the
// first call did not obtain a client certificate, instead of silently reporting success.
func (a *AcmProvider) LazyInit(ctx context.Context) error {
	var err error
	//nolint:contextcheck // false positive, sync.Once.Do doesn't take func(context.Context)
	a.once.Do(func() {
		// At least get RootCA once
		// TODO should we renew RootCA periodically?
		if _, err = a.RootCAs(ctx); err != nil {
			logger.Errorf("Failed to get CAs from ACM: %s", err.Error())
			return
		}
		// At least get Certificate once
		var cert *tls.Certificate
		cert, err = a.generateClientCertificate(ctx)
		if cert == nil {
			logger.Errorf("Failed to get certificate from ACM: %v", err)
			return
		}
		// A non-nil err at this point means only the file caching failed. The error is still
		// reported to this caller, but the certificate is valid and must keep being renewed.
		renewIntervalFunc := certsource.RenewRepeatIntervalFunc(time.Duration(a.props.MinRenewInterval))
		delay := renewIntervalFunc(cert, nil)
		loopCtx, cancelFunc := a.monitor.Run(a.lcCtx)
		a.mutex.Lock()
		a.monitorCancel = cancelFunc
		a.mutex.Unlock()
		time.AfterFunc(delay, func() {
			a.monitor.Repeat(a.tryRenew(loopCtx), func(opt *loop.TaskOption) {
				opt.RepeatIntervalFunc = renewIntervalFunc
			})
		})
	})
	if err != nil {
		return err
	}

	a.mutex.RLock()
	ready := a.cachedCertificate != nil
	a.mutex.RUnlock()
	if !ready {
		return fmt.Errorf("%s certificate source for [%s] failed to initialize", sourceType, a.props.ARN)
	}
	return nil
}
// toGetClientCertificateFunc returns a tls.Config.GetClientCertificate callback that serves the
// in-memory client certificate. If none is available yet, or the server does not accept it, an
// empty certificate is returned so that no client certificate is sent.
func (a *AcmProvider) toGetClientCertificateFunc() func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
	return func(req *tls.CertificateRequestInfo) (*tls.Certificate, error) {
		a.mutex.RLock()
		defer a.mutex.RUnlock()
		if cert := a.cachedCertificate; cert != nil && req.SupportsCertificate(cert) == nil {
			return cert, nil
		}
		// No acceptable certificate found. Don't send a certificate. Don't need to treat as error.
		// see tls package's func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error)
		return new(tls.Certificate), nil
	}
}
// generateClientCertificate exports the certificate identified by the ARN property from ACM,
// decrypts its PKCS#8 private key with the configured passphrase, and builds a tls.Certificate.
// The result becomes the in-memory client certificate and, when the file cache is enabled, is also
// written to the cache. If only the caching fails, the certificate is still returned together with
// the caching error. Malformed exports and unexpected key types yield errors rather than panics.
func (a *AcmProvider) generateClientCertificate(ctx context.Context) (*tls.Certificate, error) {
	input := &acm.ExportCertificateInput{
		CertificateArn: &a.props.ARN,
		Passphrase:     []byte(a.props.Passphrase),
	}
	output, err := a.acmClient.ExportCertificate(ctx, input)
	if err != nil {
		logger.Errorf("Could not fetch ACM certificate %s: %s", a.props.ARN, err.Error())
		return nil, err
	}
	if output.Certificate == nil || output.PrivateKey == nil {
		return nil, fmt.Errorf("exported ACM certificate %s is missing certificate or private key", a.props.ARN)
	}
	crtPEM := []byte(*output.Certificate)
	keyBlock, _ := pem.Decode([]byte(*output.PrivateKey))
	if keyBlock == nil {
		return nil, fmt.Errorf("private key exported for ACM certificate %s is not valid PEM", a.props.ARN)
	}
	unEncryptedKey, err := pemutil.DecryptPKCS8PrivateKey(keyBlock.Bytes, []byte(a.props.Passphrase))
	if err != nil {
		logger.Errorf("Could not decrypt pkcs8 private key: %s", err.Error())
		return nil, err
	}
	privateKey, err := x509.ParsePKCS8PrivateKey(unEncryptedKey)
	if err != nil {
		logger.Errorf("Could not parse pkcs8 private key: %s", err.Error())
		return nil, err
	}

	// Re-encode the decrypted key as an unencrypted PEM block that tls.X509KeyPair can load.
	var plainKey *pem.Block
	switch key := privateKey.(type) {
	case *rsa.PrivateKey:
		plainKey = &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
	case *ecdsa.PrivateKey:
		der, e := x509.MarshalECPrivateKey(key)
		if e != nil {
			logger.Errorf("Could not marshal ecdsa private key: %s", e.Error())
			return nil, e
		}
		plainKey = &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}
	default:
		// Other key types (e.g. ed25519) have no legacy encoding; fall back to unencrypted PKCS#8.
		// Returning an error instead of panicking keeps the renewal loop alive.
		der, e := x509.MarshalPKCS8PrivateKey(key)
		if e != nil {
			return nil, fmt.Errorf("unsupported private key type %T: %w", key, e)
		}
		plainKey = &pem.Block{Type: "PRIVATE KEY", Bytes: der}
	}

	cert, err := tls.X509KeyPair(crtPEM, pem.EncodeToMemory(plainKey))
	if err != nil {
		logger.Errorf("Could not create cert from PEM: %s", err.Error())
		return nil, err
	}

	a.mutex.Lock()
	defer a.mutex.Unlock()
	a.cachedCertificate = &cert
	if a.cache != nil {
		if err := a.cache.CacheCertificate(&cert); err != nil {
			logger.WithContext(ctx).Warnf(`unable to cache certificate: %v`, err)
			return &cert, err
		}
	}
	return &cert, nil
}
// renewClientCertificate asks ACM to renew the certificate identified by the ARN property.
// It does not export or cache the certificate; failures are logged and returned.
func (a *AcmProvider) renewClientCertificate(ctx context.Context) error {
	_, err := a.acmClient.RenewCertificate(ctx, &acm.RenewCertificateInput{
		CertificateArn: &a.props.ARN,
	})
	if err == nil {
		return nil
	}
	logger.Errorf("Could not renew ACM certificate %s: %s", a.props.ARN, err.Error())
	return err
}
// tryRenew returns the renewal task run by the monitor loop. Each run asks ACM to renew the
// certificate, then re-exports it, which also refreshes the in-memory client certificate. The task
// result is the exported *tls.Certificate and its error; the repeat-interval function uses both to
// schedule the next run.
func (a *AcmProvider) tryRenew(loopCtx context.Context) loop.TaskFunc {
	return func(_ context.Context, _ *loop.Loop) (ret interface{}, err error) {
		// A failed renewal request is not fatal: re-exporting may still yield a usable
		// certificate, and another attempt is scheduled either way.
		if renewErr := a.renewClientCertificate(loopCtx); renewErr != nil {
			logger.Warnf("certificate renew failed: %v", renewErr)
		}
		ret, err = a.generateClientCertificate(loopCtx)
		if err != nil {
			logger.Warnf("certificate renew failed: %v", err)
		} else {
			logger.Infof("certificate has been renewed")
		}
		return ret, err
	}
}
// arnRegex splits an AWS ARN into its components. The optional "res_type" group captures the
// resource type that precedes "/" or ":", e.g. "certificate" in "...:certificate/<id>".
var arnRegex = regexp.MustCompile(`arn:(?P<part>[^:]+):(?P<srv>[^:]+):(?P<region>[^:]+):(?P<acct>[^:]+):((?P<res_type>[^:]+)[\/:])?(?P<res_id>[^:]+)$`)

// cacheKeyReplacer replaces characters that are undesirable in cache file names.
var cacheKeyReplacer = strings.NewReplacer(
	" ", "-",
	".", "-",
	"_", "-",
	"@", "-at-",
)

// cacheKeyCount is appended to every generated cache key, so sources created for the same ARN
// within one process get different cache file prefixes.
var cacheKeyCount int
// resolveCacheKey derives a file-cache prefix from the resource type and id found in p.ARN, then
// appends the incremented process-wide cacheKeyCount and normalizes it with cacheKeyReplacer.
// When the ARN does not match arnRegex, the type and id parts are empty.
// NOTE(review): cacheKeyCount is updated without synchronization; this assumes providers are not
// constructed concurrently.
func resolveCacheKey(p *SourceProperties) (key string) {
	var resType, resID string
	if matches := arnRegex.FindStringSubmatch(p.ARN); matches != nil {
		resType = matches[arnRegex.SubexpIndex("res_type")]
		resID = matches[arnRegex.SubexpIndex("res_id")]
	}
	cacheKeyCount++
	return cacheKeyReplacer.Replace(fmt.Sprintf(`%s-%s-%d`, resType, resID, cacheKeyCount))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package acmcerts
import (
"encoding/json"
"fmt"
"github.com/aws/aws-sdk-go-v2/service/acm"
awsclient "github.com/cisco-open/go-lanai/pkg/aws"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/certs"
certsource "github.com/cisco-open/go-lanai/pkg/certs/source"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
)
// logger is the logger of the ACM certificate source package.
var logger = log.New("Certs.ACM")

const (
	// sourceType is the certificate source type served by this package.
	sourceType = certs.SourceACM
)

// Module provides the ACM certificate source factory to the certs.FxGroup group.
var Module = &bootstrap.Module{
	Name:       "certs-acm",
	Precedence: bootstrap.TlsConfigPrecedence,
	Options: []fx.Option{
		fx.Provide(FxProvider()),
	},
}

// Use registers Module with the bootstrap framework.
func Use() {
	bootstrap.Register(Module)
}
// factoryDI declares the dependencies of the ACM source factory provider. Either AcmClient or
// AwsConfigLoader is needed for ACM sources to be supported.
type factoryDI struct {
	fx.In
	AppCtx          *bootstrap.ApplicationContext
	Props           certs.Properties       `optional:"true"`
	AcmClient       *acm.Client            `optional:"true"`
	AwsConfigLoader awsclient.ConfigLoader `optional:"true"`
}

// FxProvider returns the fx provider of the ACM certificate source factory, contributed to the
// certs.FxGroup group. An injected ACM client is preferred; otherwise one is created from the AWS
// config loader. When neither is available a warning is logged and a nil factory is contributed.
func FxProvider() fx.Annotated {
	newFactory := func(di factoryDI) (certs.SourceFactory, error) {
		client := di.AcmClient
		if client == nil {
			if di.AwsConfigLoader == nil {
				logger.Warnf(`AWS/ACM certificates source is not supported. Tips: Do not forget to initialize ACM client or AWS config loader.`)
				return nil, nil
			}
			cfg, err := di.AwsConfigLoader.Load(di.AppCtx)
			if err != nil {
				return nil, fmt.Errorf(`unable to initialize AWS/ACM certificate source: %v`, err)
			}
			client = acm.NewFromConfig(cfg)
		}

		var rawDefaults json.RawMessage
		if di.Props.Sources != nil {
			rawDefaults = di.Props.Sources[sourceType]
		}
		factory, err := certsource.NewFactory[SourceProperties](sourceType, rawDefaults, func(props SourceProperties) certs.Source {
			return NewAcmProvider(di.AppCtx, client, props)
		})
		if err != nil {
			return nil, fmt.Errorf(`unable to register certificate source type [%s]: %v`, sourceType, err)
		}
		return factory, nil
	}
	return fx.Annotated{
		Group:  certs.FxGroup,
		Target: newFactory,
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package certsource
import (
"context"
"crypto/tls"
"crypto/x509"
"dario.cat/mergo"
"encoding/json"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils/loop"
"time"
)
// tlsVersions maps the supported "min TLS version" property values to crypto/tls version numbers.
// An empty value maps to 0, which leaves tls.Config.MinVersion unset so that crypto/tls applies its
// own default minimum (TLS 1.2 for clients since Go 1.18). Forcing TLS 1.0 here would silently
// weaken every config built without an explicit version.
var tlsVersions = map[string]uint16{
	"":      0,
	"tls10": tls.VersionTLS10,
	"tls11": tls.VersionTLS11,
	"tls12": tls.VersionTLS12,
	"tls13": tls.VersionTLS13,
}
// logger is the logger of the certsource package.
var logger = log.New("Certs.Source")
// NewFactory creates a GenericFactory for source type typ. Its Defaults are parsed from
// rawDefaultConfig (the zero PropertiesType when empty). Constructor builds a source from the
// merged properties of each source. Fails when the defaults cannot be parsed or constructor is nil.
func NewFactory[PropertiesType any](typ certs.SourceType, rawDefaultConfig json.RawMessage, constructor func(props PropertiesType) certs.Source) (*GenericFactory[PropertiesType], error) {
	var zero PropertiesType
	defaults, err := ParseConfigWithDefaults(zero, rawDefaultConfig)
	switch {
	case err != nil:
		return nil, fmt.Errorf(`unable to parse default certificate source configuration with type [%s]: %v`, typ, err)
	case constructor == nil:
		return nil, fmt.Errorf(`constructor of certificate source with type [%s] is missing`, typ)
	}
	factory := &GenericFactory[PropertiesType]{
		SourceType:  typ,
		Defaults:    defaults,
		Constructor: constructor,
	}
	return factory, nil
}
// GenericFactory is a source factory that parses each source's raw JSON configuration into
// PropertiesType on top of Defaults and passes the result to Constructor.
type GenericFactory[PropertiesType any] struct {
	SourceType  certs.SourceType
	Defaults    PropertiesType
	Constructor func(props PropertiesType) certs.Source
}

// Type returns the source type this factory creates.
func (f *GenericFactory[T]) Type() certs.SourceType { return f.SourceType }

// LoadAndInit applies opts to an empty certs.SourceConfig, merges its RawConfig over Defaults,
// and constructs the source. The context is not used.
func (f *GenericFactory[T]) LoadAndInit(_ context.Context, opts ...certs.SourceOptions) (certs.Source, error) {
	var cfg certs.SourceConfig
	for _, apply := range opts {
		apply(&cfg)
	}
	props, err := ParseConfigWithDefaults(f.Defaults, cfg.RawConfig)
	if err != nil {
		return nil, fmt.Errorf(`unable to initialize certificate source [%s]: %v`, f.Type(), err)
	}
	return f.Constructor(props), nil
}
// ParseConfigWithDefaults unmarshals rawConfig into a new T and merges it over defaults with
// mergo.WithOverride. An empty rawConfig returns defaults unchanged. On error, defaults (possibly
// already partially merged) is returned together with the error.
func ParseConfigWithDefaults[T any](defaults T, rawConfig json.RawMessage) (T, error) {
	if len(rawConfig) == 0 {
		return defaults, nil
	}
	var parsed T
	err := json.Unmarshal(rawConfig, &parsed)
	if err == nil {
		err = mergo.Merge(&defaults, &parsed, mergo.WithOverride)
	}
	return defaults, err
}
// ParseTLSVersion maps a "min TLS version" property value (e.g. "tls12") to its crypto/tls version.
// Unknown values yield tls.VersionTLS10 together with an error.
func ParseTLSVersion(verStr string) (uint16, error) {
	v, ok := tlsVersions[verStr]
	if !ok {
		return tls.VersionTLS10, errors.New("unsupported tls version " + verStr)
	}
	return v, nil
}
// RenewRepeatIntervalFunc creates a loop.RepeatIntervalFunc for renewing certificates.
// The task result is expected to be a *tls.Certificate. The next renewal is scheduled half way
// between now and the leaf certificate's expiration (NotAfter), but never sooner than a minimum
// interval. The minimum is "fallbackInterval" when non-zero, otherwise 1 minute. The minimum is
// also returned whenever no better interval can be computed: a task error, a result that is not a
// usable *tls.Certificate, an unparsable leaf certificate, or an already expired certificate.
func RenewRepeatIntervalFunc(fallbackInterval time.Duration) loop.RepeatIntervalFunc {
	return func(result interface{}, err error) (ret time.Duration) {
		defer func() {
			logger.Debugf("certificate will renew in %v", ret)
		}()
		minDuration := 1 * time.Minute
		if fallbackInterval != 0 {
			minDuration = fallbackInterval
		}
		if err != nil {
			return minDuration
		}
		// checked assertion: a nil or unexpected result must not panic inside the renewal loop
		cert, ok := result.(*tls.Certificate)
		if !ok || cert == nil || len(cert.Certificate) < 1 {
			return minDuration
		}
		parsedCert, parseErr := x509.ParseCertificate(cert.Certificate[0])
		if parseErr != nil {
			return minDuration
		}
		remaining := time.Until(parsedCert.NotAfter)
		if remaining < 0 {
			return minDuration
		}
		next := remaining / 2
		if next < minDuration {
			next = minDuration
		}
		return next
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package filecerts
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"io"
	"os"
	"path/filepath"

	"github.com/cisco-open/go-lanai/pkg/certs"
	certsource "github.com/cisco-open/go-lanai/pkg/certs/source"
)
// FileProvider is a certificate source that reads the CA, certificate and private key from local files.
type FileProvider struct {
	p SourceProperties
}

// NewFileProvider creates a file-based certificate source configured by p.
func NewFileProvider(p SourceProperties) certs.Source {
	provider := new(FileProvider)
	provider.p = p
	return provider
}
// TLSConfig returns a client tls.Config. RootCAs are loaded from the CA file, the client
// certificate is read from disk on each handshake, and MinVersion comes from the MinTLSVersion property.
func (f *FileProvider) TLSConfig(ctx context.Context, _ ...certs.TLSOptions) (*tls.Config, error) {
	pool, err := f.RootCAs(ctx)
	if err != nil {
		return nil, err
	}
	minVersion, err := certsource.ParseTLSVersion(f.p.MinTLSVersion)
	if err != nil {
		return nil, err
	}
	//nolint:gosec // false positive - G402: TLS MinVersion too low
	return &tls.Config{
		GetClientCertificate: f.toGetClientCertificateFunc(),
		RootCAs:              pool,
		MinVersion:           minVersion,
	}, nil
}
// Files returns the configured CA, certificate and key paths made absolute (falling back to the
// path as configured when it cannot be made absolute), together with the key passphrase.
func (f *FileProvider) Files(_ context.Context) (*certs.CertificateFiles, error) {
	files := certs.CertificateFiles{
		RootCAPaths:          []string{f.toAbsPath(f.p.CACertFile)},
		CertificatePath:      f.toAbsPath(f.p.CertFile),
		PrivateKeyPath:       f.toAbsPath(f.p.KeyFile),
		PrivateKeyPassphrase: f.p.KeyPass,
	}
	return &files, nil
}
// RootCAs reads the configured CA file and returns a pool containing the PEM certificates it holds.
// Certificates that fail to parse are skipped, as AppendCertsFromPEM does.
func (f *FileProvider) RootCAs(_ context.Context) (*x509.CertPool, error) {
	pemData, err := os.ReadFile(f.p.CACertFile)
	if err != nil {
		return nil, err
	}
	pool := x509.NewCertPool()
	pool.AppendCertsFromPEM(pemData)
	return pool, nil
}
// toGetClientCertificateFunc returns a tls.Config.GetClientCertificate callback. On every call it
// re-reads the key and certificate files, decrypting the key with KeyPass when one is set, and
// builds the key pair. The pair is returned if the server accepts it; otherwise an empty
// certificate is returned so that no client certificate is sent.
func (f *FileProvider) toGetClientCertificateFunc() func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
	// readFile reads a whole file and always closes it, so repeated handshakes do not leak
	// file descriptors.
	readFile := func(path string) ([]byte, error) {
		file, err := os.Open(path)
		if err != nil {
			return nil, err
		}
		defer func() { _ = file.Close() }()
		return io.ReadAll(file)
	}

	return func(certificateReq *tls.CertificateRequestInfo) (*tls.Certificate, error) {
		keyBytes, err := readFile(f.p.KeyFile)
		if err != nil {
			return nil, err
		}
		if f.p.KeyPass != "" {
			keyBlock, _ := pem.Decode(keyBytes)
			if keyBlock == nil {
				return nil, errors.New("unable to decode PEM block from private key file " + f.p.KeyFile)
			}
			//nolint:staticcheck // x509.DecryptPEMBlock is deprecated but required for legacy encrypted PEM keys
			unEncryptedKey, e := x509.DecryptPEMBlock(keyBlock, []byte(f.p.KeyPass))
			if e != nil {
				return nil, e
			}
			keyBlock.Bytes = unEncryptedKey
			keyBlock.Headers = nil
			keyBytes = pem.EncodeToMemory(keyBlock)
		}
		certBytes, err := readFile(f.p.CertFile)
		if err != nil {
			return nil, err
		}
		clientCert, err := tls.X509KeyPair(certBytes, keyBytes)
		if err != nil {
			return nil, err
		}
		if e := certificateReq.SupportsCertificate(&clientCert); e != nil {
			// No acceptable certificate found. Don't send a certificate. Don't need to treat as error.
			// see tls package's tls.Conn.getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error)
			return new(tls.Certificate), nil //nolint:nilerr // as intended
		}
		return &clientCert, nil
	}
}
// toAbsPath returns path made absolute, or path unchanged if it cannot be made absolute.
func (f *FileProvider) toAbsPath(path string) string {
	if abs, err := filepath.Abs(path); err == nil {
		return abs
	}
	return path
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package filecerts
import (
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/certs"
certsource "github.com/cisco-open/go-lanai/pkg/certs/source"
"go.uber.org/fx"
)
const (
	// sourceType is the certificate source type served by this package.
	sourceType = certs.SourceFile
)

// factoryDI declares the optional dependencies of the file source factory provider.
type factoryDI struct {
	fx.In
	Props certs.Properties `optional:"true"`
}

// FxProvider returns the fx provider of the file-based certificate source factory, contributed to
// the certs.FxGroup group. Defaults for every file source come from Props.Sources for this type.
func FxProvider() fx.Annotated {
	newFactory := func(di factoryDI) (certs.SourceFactory, error) {
		var rawDefaults json.RawMessage
		if di.Props.Sources != nil {
			rawDefaults = di.Props.Sources[sourceType]
		}
		factory, err := certsource.NewFactory[SourceProperties](sourceType, rawDefaults, NewFileProvider)
		if err != nil {
			return nil, fmt.Errorf(`unable to register certificate source type [%s]: %v`, sourceType, err)
		}
		return factory, nil
	}
	return fx.Annotated{
		Group:  certs.FxGroup,
		Target: newFactory,
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package certsource
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/cisco-open/go-lanai/pkg/utils"
"os"
"path/filepath"
)
const (
	// DefaultCacheRoot is the cache root directory used when FileCacheOption.Root is empty.
	DefaultCacheRoot = `.tmp/certs`
	// CachedFileKeyCertificate is the cache key of the certificate file.
	CachedFileKeyCertificate = `cert`
	// CachedFileKeyPrivateKey is the cache key of the private key file.
	CachedFileKeyPrivateKey = `key`
	// CachedFileKeyCA is the cache key of the CA file.
	CachedFileKeyCA = `ca`
)

// FileCacheOptions customizes a FileCacheOption.
type FileCacheOptions func(opt *FileCacheOption)

// FileCacheOption configures NewFileCache.
type FileCacheOption struct {
	// Root is the cache root directory; DefaultCacheRoot when empty.
	Root string
	// Type names the sub-directory under Root that holds the files.
	Type certs.SourceType
	// Prefix is prepended to every cached file name; a random 12-character string when empty.
	Prefix string
}
// NewFileCache creates the directory <Root>/<Type> (mode 0755, parents included) and returns a
// FileCache writing into it. An empty Root defaults to DefaultCacheRoot; an empty Prefix defaults
// to a random 12-character string.
func NewFileCache(opts ...FileCacheOptions) (*FileCache, error) {
	var opt FileCacheOption
	for _, apply := range opts {
		apply(&opt)
	}
	root := opt.Root
	if root == "" {
		root = DefaultCacheRoot
	}
	prefix := opt.Prefix
	if prefix == "" {
		prefix = utils.RandomString(12)
	}
	dir := filepath.Clean(filepath.Join(root, string(opt.Type)))
	if err := os.MkdirAll(dir, 0755); err != nil {
		return nil, err
	}
	return &FileCache{Dir: dir, Prefix: prefix}, nil
}
// FileCache writes PEM files named "<Prefix>-<key>.pem" into Dir.
type FileCache struct {
	Dir    string
	Prefix string
}
// CacheCertificate writes the leaf certificate (cert.Certificate[0]) as a "CERTIFICATE" PEM file and
// the private key as a PKCS#8 "PRIVATE KEY" PEM file into the cache. It fails when cert holds no
// certificate or when encoding or writing either file fails.
func (c *FileCache) CacheCertificate(cert *tls.Certificate) error {
	if len(cert.Certificate) == 0 {
		return fmt.Errorf("no certificates present in provided tls.Certificate")
	}
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Certificate[0]})
	if certPEM == nil {
		return fmt.Errorf("failed to encode certificate to PEM")
	}
	if err := c.CachePEM(certPEM, CachedFileKeyCertificate); err != nil {
		return fmt.Errorf("failed to write PEM data to file: %v", err)
	}

	keyDER, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey)
	if err != nil {
		return fmt.Errorf("unable to marshal private key: %v", err)
	}
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyDER})
	if keyPEM == nil {
		return fmt.Errorf("failed to encode private key to PEM")
	}
	if err := c.CachePEM(keyPEM, CachedFileKeyPrivateKey); err != nil {
		return fmt.Errorf("failed to write private key PEM data to file: %v", err)
	}
	return nil
}
// CachePEM writes pemData to the file returned by ResolvePath(key), i.e. "<Prefix>-<key>.pem" in Dir.
// A new file is created with mode 0600; an existing file is truncated and overwritten.
func (c *FileCache) CachePEM(pemData []byte, key string) error {
	if err := os.WriteFile(c.ResolvePath(key), pemData, 0600); err != nil {
		return fmt.Errorf("failed to write PEM data to file: %v", err)
	}
	return nil
}
// ResolvePath returns the cleaned path of the cache file for key: "<Prefix>-<key>.pem" in Dir.
func (c *FileCache) ResolvePath(key string) string {
	// filepath.Join already cleans its result
	return filepath.Join(c.Dir, c.Prefix+"-"+key+".pem")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vaultcerts
import (
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/certs"
certsource "github.com/cisco-open/go-lanai/pkg/certs/source"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/vault"
"go.uber.org/fx"
)
// logger is the package logger of the Vault certificate source.
var logger = log.New("Certs.Vault")
const (
// sourceType is the certificate source type this package provides a factory for.
sourceType = certs.SourceVault
)
// Module is the bootstrap module of the Vault certificate source. It provides the SourceFactory built by
// FxProvider, with precedence bootstrap.TlsConfigPrecedence.
var Module = &bootstrap.Module{
Name: "certs-vault",
Precedence: bootstrap.TlsConfigPrecedence,
Options: []fx.Option{
fx.Provide(FxProvider()),
},
}
// Use registers Module with the bootstrap process.
func Use() {
bootstrap.Register(Module)
}
// factoryDI holds the fx-injected dependencies of the factory built by FxProvider.
type factoryDI struct {
fx.In
// AppCtx is handed to every provider created by the factory (NewVaultProvider's ctx).
AppCtx *bootstrap.ApplicationContext
// Props is optional; its Sources entry for sourceType, when present, provides default source properties.
Props certs.Properties `optional:"true"`
// VaultClient is optional; without it no factory is provided.
VaultClient *vault.Client `optional:"true"`
}
// FxProvider returns an fx annotation that contributes the Vault certs.SourceFactory into the certs.FxGroup group.
// When no *vault.Client is available, a warning is logged and a nil factory is returned without error, so the
// application can still start without Vault certificate support.
// Default source properties are taken from certs.Properties.Sources[sourceType] when present.
func FxProvider() fx.Annotated {
	return fx.Annotated{
		Group: certs.FxGroup,
		Target: func(di factoryDI) (certs.SourceFactory, error) {
			if di.VaultClient == nil {
				logger.Warnf(`Vault Certificates source is not supported. Tips: Do not forget to initialize vault client.`)
				return nil, nil
			}
			// indexing a nil map is safe and yields the zero value, so no nil check or ", ok" form is needed
			var rawDefaults json.RawMessage
			rawDefaults = di.Props.Sources[sourceType]
			factory, e := certsource.NewFactory[SourceProperties](sourceType, rawDefaults, func(props SourceProperties) certs.Source {
				return NewVaultProvider(di.AppCtx, di.VaultClient, props)
			})
			if e != nil {
				return nil, fmt.Errorf(`unable to register certificate source type [%s]: %w`, sourceType, e)
			}
			return factory, nil
		},
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vaultcerts
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/cisco-open/go-lanai/pkg/certs"
certsource "github.com/cisco-open/go-lanai/pkg/certs/source"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/loop"
"github.com/cisco-open/go-lanai/pkg/vault"
"io"
"path"
"strings"
"sync"
"time"
)
// VaultProvider is a certs.Source that obtains root CAs and client certificates from a Vault PKI mount
// (SourceProperties.Path) and renews the client certificate periodically in the background.
type VaultProvider struct {
// p holds the source properties (mount path, role, CN, SANs, TTL, TLS and cache settings).
p SourceProperties
// vc is the Vault client used for all requests.
vc *vault.Client
// cache writes CA/certificate/key PEM files locally; may be nil when the file cache could not be created.
cache *certsource.FileCache
// once guards the one-time initialization performed by LazyInit.
once sync.Once
// mutex guards cachedCertificate.
mutex sync.RWMutex
// cachedCertificate is the most recently issued client certificate; nil until one has been issued.
cachedCertificate *tls.Certificate
// lcCtx is the context given to NewVaultProvider; the renewal loop runs with it.
lcCtx context.Context
// monitor is the loop running periodic certificate renewal.
monitor *loop.Loop
// monitorCancel stops the renewal loop; set by LazyInit and invoked by Close.
monitorCancel context.CancelFunc
}
// NewVaultProvider creates a certs.Source backed by the Vault PKI mount described by p.
// ctx becomes the lifecycle context of the background renewal loop. If the local file cache cannot be
// created, a warning is logged and the provider is returned anyway (Files will then report an error).
func NewVaultProvider(ctx context.Context, vc *vault.Client, p SourceProperties) certs.Source {
	fileCache, cacheErr := certsource.NewFileCache(func(opt *certsource.FileCacheOption) {
		opt.Root = p.CachePath
		opt.Type = sourceType
		opt.Prefix = resolveCacheKey(&p)
	})
	if cacheErr != nil {
		logger.WithContext(ctx).Warnf("file cache for %s certificate source is not enabled: %v", sourceType, cacheErr)
	}
	provider := &VaultProvider{
		p:       p,
		vc:      vc,
		lcCtx:   ctx,
		cache:   fileCache,
		monitor: loop.NewLoop(),
	}
	return provider
}
// TLSConfig initializes the provider if needed and returns a client-side *tls.Config. The config trusts the
// Vault PKI root CAs, presents the current client certificate on demand, and uses the configured minimum
// TLS version. TLS options are ignored.
func (v *VaultProvider) TLSConfig(ctx context.Context, _ ...certs.TLSOptions) (*tls.Config, error) {
	var (
		pool   *x509.CertPool
		minTLS uint16
		err    error
	)
	if err = v.LazyInit(ctx); err != nil {
		return nil, err
	}
	if pool, err = v.RootCAs(ctx); err != nil {
		return nil, err
	}
	if minTLS, err = certsource.ParseTLSVersion(v.p.MinTLSVersion); err != nil {
		return nil, err
	}
	//nolint:gosec // false positive - G402: TLS MinVersion too low
	return &tls.Config{
		GetClientCertificate: v.toGetClientCertificateFunc(),
		RootCAs:              pool,
		MinVersion:           minTLS,
	}, nil
}
// Files initializes the provider if needed and returns the local cache file paths of the root CA,
// the client certificate and its private key. It fails when the file cache is not enabled.
func (v *VaultProvider) Files(ctx context.Context) (*certs.CertificateFiles, error) {
	if err := v.LazyInit(ctx); err != nil {
		return nil, err
	}
	fc := v.cache
	if fc == nil {
		return nil, fmt.Errorf("unable to access certificates as local files: file cache is not enabled for source [%s]", sourceType)
	}
	files := &certs.CertificateFiles{
		RootCAPaths:     []string{fc.ResolvePath(certsource.CachedFileKeyCA)},
		CertificatePath: fc.ResolvePath(certsource.CachedFileKeyCertificate),
		PrivateKeyPath:  fc.ResolvePath(certsource.CachedFileKeyPrivateKey),
	}
	return files, nil
}
// RootCAs reads the CA certificate(s) of the PKI mount in PEM format from "<Path>/ca/pem" and returns them
// as a *x509.CertPool. When the file cache is enabled, the raw PEM is also written to the CA cache file.
// A cache write failure is logged and returned together with the already-populated pool.
func (v *VaultProvider) RootCAs(ctx context.Context) (*x509.CertPool, error) {
	caPath := path.Join(v.p.Path, "ca", "pem")
	resp, err := v.vc.Logical(ctx).ReadRawWithContext(ctx, caPath)
	// The raw read may return a response together with an error (e.g. a non-2xx status);
	// close the body in every case where a response exists to avoid leaking the connection.
	if resp != nil {
		defer func(body io.ReadCloser) {
			_ = body.Close()
		}(resp.Body)
	}
	if err != nil {
		return nil, err
	}
	if resp == nil {
		return nil, fmt.Errorf("no response received when reading CA from Vault path [%s]", caPath)
	}
	pemBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	certPool := x509.NewCertPool()
	if !certPool.AppendCertsFromPEM(pemBytes) {
		// not fatal, kept backward compatible, but an empty pool makes every server verification fail
		logger.WithContext(ctx).Warnf("no valid CA certificate found in data read from Vault path [%s]", caPath)
	}
	if v.cache != nil {
		if cacheErr := v.cache.CachePEM(pemBytes, certsource.CachedFileKeyCA); cacheErr != nil {
			logger.WithContext(ctx).Warnf(`unable to cache CA: %v`, cacheErr)
			return certPool, cacheErr
		}
	}
	return certPool, nil
}
// LazyInit performs the one-time initialization of this provider. It fetches the root CAs, issues a client
// certificate, then schedules periodic renewal on the background loop. The loop runs with the context given
// to NewVaultProvider, and the first renewal is delayed by the interval computed for the issued certificate.
//
// Initialization is attempted only once (sync.Once). The first caller receives the actual error, if any.
// Later callers also receive an error when no client certificate is available, instead of a silent nil.
func (v *VaultProvider) LazyInit(ctx context.Context) error {
	var err error
	//nolint:contextcheck // false positive, sync.Once.Do doesn't take func(context.Context)
	v.once.Do(func() {
		// At least get RootCA once
		// TODO should we renew RootCA periodically?
		_, err = v.RootCAs(ctx)
		if err != nil {
			logger.WithContext(ctx).Errorf("Failed to get CAs from Vault: %s", err.Error())
			return
		}
		// At least get Certificate once
		var cert *tls.Certificate
		cert, err = v.generateClientCertificate(ctx)
		if err != nil {
			logger.WithContext(ctx).Errorf("Failed to get certificate from Vault: %s", err.Error())
			return
		}
		renewIntervalFunc := certsource.RenewRepeatIntervalFunc(time.Duration(v.p.MinRenewInterval))
		delay := renewIntervalFunc(cert, err)
		loopCtx, cancelFunc := v.monitor.Run(v.lcCtx)
		v.monitorCancel = cancelFunc
		// NOTE(review): this timer is not stopped by Close; if Close runs before it fires, Repeat is
		// scheduled on an already-cancelled loop context.
		time.AfterFunc(delay, func() {
			v.monitor.Repeat(v.tryRenew(loopCtx), func(opt *loop.TaskOption) {
				opt.RepeatIntervalFunc = renewIntervalFunc
			})
		})
	})
	if err != nil {
		return err
	}
	// sync.Once does not re-run after a failed attempt; detect that state instead of reporting success
	v.mutex.RLock()
	initialized := v.cachedCertificate != nil
	v.mutex.RUnlock()
	if !initialized {
		return fmt.Errorf("%s certificate source is not initialized: a previous initialization attempt failed", sourceType)
	}
	return nil
}
// Close stops the background renewal loop, if it was started. It always returns nil.
func (v *VaultProvider) Close() error {
	if cancel := v.monitorCancel; cancel != nil {
		cancel()
	}
	return nil
}
// toGetClientCertificateFunc returns a tls.Config.GetClientCertificate callback serving the current cached
// client certificate. It returns an empty certificate (meaning "send none") when no certificate is cached
// yet, or when the cached one does not satisfy the server's request.
func (v *VaultProvider) toGetClientCertificateFunc() func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
	return func(req *tls.CertificateRequestInfo) (*tls.Certificate, error) {
		v.mutex.RLock()
		defer v.mutex.RUnlock()
		current := v.cachedCertificate
		// No acceptable certificate: don't send one, and don't treat it as an error.
		// see tls package's func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error)
		if current == nil || req.SupportsCertificate(current) != nil {
			return new(tls.Certificate), nil //nolint:nilerr
		}
		return current, nil
	}
}
// tryRenew returns the loop task that issues a fresh client certificate using loopCtx (the renewal loop's
// context). Failures are only logged. The (result, error) pair is returned to the loop, which presumably
// feeds it to the RepeatIntervalFunc configured in LazyInit to schedule the next attempt.
func (v *VaultProvider) tryRenew(loopCtx context.Context) loop.TaskFunc {
	return func(_ context.Context, l *loop.Loop) (ret interface{}, err error) {
		//ignore error since we will just schedule another renew
		ret, err = v.generateClientCertificate(loopCtx)
		if err != nil {
			// Warnf (not Warn) so that the %v verb is actually formatted
			logger.WithContext(loopCtx).Warnf("certificate renew failed: %v", err)
		} else {
			logger.WithContext(loopCtx).Debugf("certificate has been renewed")
		}
		return
	}
}
// generateClientCertificate asks Vault to issue a new certificate at "<Path>/issue/<Role>" and parses the returned
// "certificate" and "private_key" PEM fields into a tls.Certificate. On success the result replaces the cached
// client certificate and, when the file cache is enabled, is written to the cache files. A cache write failure is
// logged and returned together with the (already cached) certificate.
// Malformed or missing response data yields an error and leaves the cached certificate untouched.
func (v *VaultProvider) generateClientCertificate(ctx context.Context) (*tls.Certificate, error) {
	fullPath := path.Join(v.p.Path, "issue", v.p.Role)
	reqData := IssueCertificateRequest{
		CommonName: v.p.CN,
		IpSans:     v.p.IpSans,
		AltNames:   v.p.AltNames,
		TTL:        v.p.TTL,
	}
	//nolint:contextcheck // context is passed in via Logical(ctx). false positive
	secret, err := v.vc.Logical(ctx).Write(fullPath, reqData)
	if err != nil {
		return nil, err
	}
	if secret == nil || secret.Data == nil {
		return nil, fmt.Errorf("no certificate data returned from Vault path [%s]", fullPath)
	}
	crtPEM, ok := secret.Data["certificate"].(string)
	if !ok {
		return nil, fmt.Errorf(`missing or invalid "certificate" in Vault response from [%s]`, fullPath)
	}
	keyPEM, ok := secret.Data["private_key"].(string)
	if !ok {
		return nil, fmt.Errorf(`missing or invalid "private_key" in Vault response from [%s]`, fullPath)
	}
	cert, err := tls.X509KeyPair([]byte(crtPEM), []byte(keyPEM))
	if err != nil {
		// never replace a good cached certificate with an unusable one
		return nil, fmt.Errorf("unable to parse certificate issued by Vault: %w", err)
	}
	v.mutex.Lock()
	defer v.mutex.Unlock()
	v.cachedCertificate = &cert
	if v.cache != nil {
		if err = v.cache.CacheCertificate(&cert); err != nil {
			logger.WithContext(ctx).Warnf(`unable to cache certificate: %v`, err)
			return &cert, err
		}
	}
	return &cert, nil
}
var (
// cacheKeyReplacer maps characters unsuitable for cache file names (space, '.', '_', '/', '\') to "-",
// and '@' to "-at-".
cacheKeyReplacer = strings.NewReplacer(
" ", "-",
".", "-",
"_", "-",
"@", "-at-",
"/", "-",
"\\", "-",
)
// cacheKeyCount is incremented for every resolved cache key, so each provider gets a distinct file prefix.
// NOTE(review): not synchronized; assumes providers are created from a single goroutine — confirm.
cacheKeyCount = 0
)
// resolveCacheKey builds a unique, file-name-safe cache prefix "<Role>-<CN>-<n>", where n is a
// process-wide counter incremented on every call.
func resolveCacheKey(p *SourceProperties) (key string) {
	cacheKeyCount++
	raw := fmt.Sprintf(`%s-%s-%d`, p.Role, p.CN, cacheKeyCount)
	return cacheKeyReplacer.Replace(raw)
}
// IssueCertificateRequest is the request body written to Vault at "<Path>/issue/<Role>".
// String fields are omitted from the JSON payload when empty (omitempty).
type IssueCertificateRequest struct {
CommonName string `json:"common_name,omitempty"`
// TTL's JSON form is defined by utils.Duration; presumably a duration Vault accepts — confirm.
TTL utils.Duration `json:"ttl,omitempty"`
// AltNames and IpSans are sent as-is; Vault expects comma-separated lists — confirm callers format them so.
AltNames string `json:"alt_names,omitempty"`
IpSans string `json:"ip_sans,omitempty"`
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consulappconfig
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/appconfig"
appconfiginit "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/consul"
)
// ProviderGroupOptions customizes a ProviderGroupOption; see NewProviderGroup.
type ProviderGroupOptions func(opt *ProviderGroupOption)
// ProviderGroupOption configures the provider group created by NewProviderGroup.
type ProviderGroupOption struct {
// Precedence of the created provider group.
Precedence int
// Prefix is the first segment of the Consul KV path.
Prefix string
// Path is the context path appended after Prefix.
Path string
// ProfileSeparator is placed between Path and a profile name for profile-specific paths.
ProfileSeparator string
// Connection is the Consul connection used by the created providers.
Connection *consul.Connection
}
// NewProviderGroup creates a Consul KV store backed appconfig.ProviderGroup.
// The provider group loads application properties from Consul KV store at paths:
// <ProviderGroupOption.Prefix>/<ProviderGroupOption.Path>[<ProviderGroupOption.ProfileSeparator><any active profile>]
// e.g.
//   - "userviceconfiguration/defaultapplication"
//   - "userviceconfiguration/defaultapplication,prod"
//   - "userviceconfiguration/my-service"
//   - "userviceconfiguration/my-service,staging"
func NewProviderGroup(opts ...ProviderGroupOptions) appconfig.ProviderGroup {
	opt := ProviderGroupOption{
		Precedence:       appconfiginit.PrecedenceExternalDefaultContext,
		Prefix:           DefaultConfigPathPrefix,
		Path:             DefaultConfigPath,
		ProfileSeparator: DefaultProfileSeparator,
	}
	for i := range opts {
		opts[i](&opt)
	}
	group := appconfig.NewProfileBasedProviderGroup(opt.Precedence)
	group.KeyFunc = func(profile string) string {
		base := fmt.Sprintf("%s/%s", opt.Prefix, opt.Path)
		if profile == "" {
			return base
		}
		return base + opt.ProfileSeparator + profile
	}
	group.CreateFunc = func(name string, order int, _ bootstrap.ApplicationConfig) appconfig.Provider {
		// never wrap a nil *ConfigProvider into a non-nil interface value
		if provider := NewConfigProvider(order, name, opt.Connection); provider != nil {
			return provider
		}
		return nil
	}
	return group
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consulappconfig
import (
"github.com/cisco-open/go-lanai/pkg/appconfig"
appconfiginit "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/consul"
"go.uber.org/fx"
)
// Module provides ConsulConfigProperties and two Consul KV backed appconfig.ProviderGroup instances
// (default context and application context) with priority, at bootstrap.AppConfigPrecedence.
// NOTE(review): the Name "bootstrap endpoint" does not describe this module; presumably a leftover — confirm
// before renaming, in case the name is referenced elsewhere.
var Module = &bootstrap.Module{
Name: "bootstrap endpoint",
Precedence: bootstrap.AppConfigPrecedence,
PriorityOptions: []fx.Option{
fx.Provide(
bindConsulConfigProperties,
fxNewConsulDefaultContextProviderGroup,
fxNewConsulAppContextProviderGroup,
),
},
}
// groupDI holds the dependencies of the Consul provider group constructors.
type groupDI struct {
fx.In
BootstrapConfig *appconfig.BootstrapConfig
ConsulConfigProperties ConsulConfigProperties
// ConsulConnection is optional; without it no provider group is contributed.
ConsulConnection *consul.Connection `optional:"true"`
}
// appConfigProvidersOut contributes a provider group into the "application-config" fx group.
// A zero value (nil ProviderGroup) is returned when Consul config is disabled or unavailable.
type appConfigProvidersOut struct {
fx.Out
ProviderGroup appconfig.ProviderGroup `group:"application-config"`
}
// withProperties returns an option that copies Prefix, DefaultContext (as Path) and ProfileSeparator from props.
// props is read when the option is applied, not when it is created.
func withProperties(props *ConsulConfigProperties) ProviderGroupOptions {
	return func(opt *ProviderGroupOption) {
		opt.ProfileSeparator = props.ProfileSeparator
		opt.Path = props.DefaultContext
		opt.Prefix = props.Prefix
	}
}
// fxNewConsulDefaultContextProviderGroup contributes the provider group for the shared default context path
// (ConsulConfigProperties.DefaultContext). It contributes nothing when Consul config is disabled or no
// connection is available.
func fxNewConsulDefaultContextProviderGroup(di groupDI) appConfigProvidersOut {
	var out appConfigProvidersOut
	if di.ConsulConfigProperties.Enabled && di.ConsulConnection != nil {
		out.ProviderGroup = NewProviderGroup(withProperties(&di.ConsulConfigProperties),
			func(opt *ProviderGroupOption) {
				opt.Precedence = appconfiginit.PrecedenceExternalDefaultContext
				opt.Connection = di.ConsulConnection
			},
		)
	}
	return out
}
// fxNewConsulAppContextProviderGroup contributes the provider group for the application-specific path. That path
// is the bootstrap application name, or empty if it is not a string. It contributes nothing when Consul config is
// disabled or no connection is available.
func fxNewConsulAppContextProviderGroup(di groupDI) appConfigProvidersOut {
	var out appConfigProvidersOut
	if !di.ConsulConfigProperties.Enabled || di.ConsulConnection == nil {
		return out
	}
	appName, _ := di.BootstrapConfig.Value(bootstrap.PropertyKeyApplicationName).(string)
	out.ProviderGroup = NewProviderGroup(withProperties(&di.ConsulConfigProperties),
		func(opt *ProviderGroupOption) {
			opt.Precedence = appconfiginit.PrecedenceExternalAppContext
			opt.Path = appName
			opt.Connection = di.ConsulConnection
		},
	)
	return out
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consulappconfig
import "github.com/cisco-open/go-lanai/pkg/appconfig"
const (
// PropertiesPrefix is the configuration prefix bound into ConsulConfigProperties.
PropertiesPrefix = "cloud.consul.config"
// DefaultConfigPathPrefix is the default KV path prefix (ConsulConfigProperties.Prefix).
DefaultConfigPathPrefix = "userviceconfiguration"
// DefaultConfigPath is the default context path (ConsulConfigProperties.DefaultContext).
DefaultConfigPath = "defaultapplication"
// DefaultProfileSeparator separates a context path from a profile name in KV paths.
DefaultProfileSeparator = ","
)
// ConsulConfigProperties controls loading application configuration from Consul KV.
// It is bound from PropertiesPrefix.
type ConsulConfigProperties struct {
Enabled bool `json:"enabled"`
Prefix string `json:"prefix"`
DefaultContext string `json:"default-context"`
ProfileSeparator string `json:"profile-separator"`
}
// bindConsulConfigProperties binds ConsulConfigProperties from the bootstrap configuration under
// PropertiesPrefix. Unset values keep their defaults: enabled, the default prefix, context path and separator.
// The properties are returned even when binding fails.
func bindConsulConfigProperties(bootstrapConfig *appconfig.BootstrapConfig) (ConsulConfigProperties, error) {
	props := ConsulConfigProperties{
		Enabled:          true,
		Prefix:           DefaultConfigPathPrefix,
		DefaultContext:   DefaultConfigPath,
		ProfileSeparator: DefaultProfileSeparator,
	}
	err := bootstrapConfig.Bind(&props, PropertiesPrefix)
	return props, err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consulappconfig
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/appconfig"
"github.com/cisco-open/go-lanai/pkg/consul"
"github.com/cisco-open/go-lanai/pkg/log"
)
// logger is the package logger of the Consul config provider.
var logger = log.New("Config.Consul")
// ConfigProvider is an appconfig provider that loads all key/value pairs under contextPath from Consul KV.
type ConfigProvider struct {
appconfig.ProviderMeta
// contextPath is the KV path whose child keys become properties.
contextPath string
// connection is the Consul connection the KV pairs are read from.
connection *consul.Connection
}
// Name identifies the provider as "consul:<contextPath>".
func (configProvider *ConfigProvider) Name() string {
	path := configProvider.contextPath
	return fmt.Sprintf("consul:%s", path)
}
// Load reads all KV pairs under contextPath, un-flattens the dotted keys into nested settings, and stores
// them in Settings. Settings is reset first, so a failed load leaves it empty. Loaded is updated on return
// to reflect whether this call succeeded.
func (configProvider *ConfigProvider) Load(ctx context.Context) (loadError error) {
	defer func() {
		configProvider.Loaded = loadError == nil
	}()
	configProvider.Settings = make(map[string]interface{})
	flat, err := configProvider.connection.ListKeyValuePairs(ctx, configProvider.contextPath)
	if err != nil {
		return err
	}
	nested, err := appconfig.UnFlatten(flat)
	if err != nil {
		return err
	}
	configProvider.Settings = nested
	logger.WithContext(ctx).Infof("Retrieved %d configs from consul: %s", len(flat), configProvider.contextPath)
	return nil
}
// NewConfigProvider creates a ConfigProvider with the given precedence, which reads the full KV path
// contextPath through conn.
func NewConfigProvider(precedence int, contextPath string, conn *consul.Connection) *ConfigProvider {
	provider := &ConfigProvider{
		contextPath: contextPath,
		connection:  conn,
	}
	provider.ProviderMeta = appconfig.ProviderMeta{Precedence: precedence}
	return provider
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consul
import "github.com/hashicorp/consul/api"
// ClientAuthentication obtains the ACL token a Consul client should use.
// Login receives a client built from the connection config (without token) and returns the token to set.
// TODO review ClientAuthentication and KubernetesClient
type ClientAuthentication interface {
Login(client *api.Client) (token string, err error)
}
// newClientAuthentication selects the authentication for p.Authentication: Kubernetes login for Kubernetes.
// Static token authentication (p.Token) is used for Token and for any other value.
func newClientAuthentication(p *ConnectionProperties) ClientAuthentication {
	if p.Authentication == Kubernetes {
		return TokenKubernetesAuthentication(p.Kubernetes)
	}
	return TokenClientAuthentication(p.Token)
}
// TokenClientAuthentication is a ClientAuthentication that always yields the configured static token.
type TokenClientAuthentication string

// Login returns the static token; the client is not used and no error is ever returned.
func (d TokenClientAuthentication) Login(_ *api.Client) (token string, err error) {
	token = string(d)
	return token, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consul
import (
"fmt"
"github.com/hashicorp/consul/api"
"os"
)
// KubernetesClient is a ClientAuthentication that logs into Consul with a Kubernetes service account JWT
// through the configured Consul ACL auth method.
type KubernetesClient struct {
config KubernetesConfig
}
// Login reads the Kubernetes service account JWT from config.JWTPath, or from the default in-cluster token path
// when JWTPath is empty. It exchanges the JWT for a Consul ACL token via auth method config.Method and returns
// that token's SecretID.
func (c *KubernetesClient) Login(client *api.Client) (string, error) {
	const defaultJWTPath = "/var/run/secrets/kubernetes.io/serviceaccount/token"
	// resolve the path locally instead of writing back into c.config, so concurrent logins don't race on it
	jwtPath := c.config.JWTPath
	if jwtPath == "" {
		jwtPath = defaultJWTPath
	}
	jwtToken, err := readTokenFromFile(jwtPath)
	if err != nil {
		return "", err
	}
	options := &api.ACLLoginParams{
		AuthMethod:  c.config.Method,
		BearerToken: jwtToken,
	}
	authToken, _, err := client.ACL().Login(options, nil)
	if err != nil {
		return "", err
	}
	if authToken == nil {
		return "", fmt.Errorf("consul login with auth method [%s] returned no token", c.config.Method)
	}
	logger.Info("Successfully obtained Consul token using k8s auth")
	return authToken.SecretID, nil
}
// TokenKubernetesAuthentication creates a KubernetesClient authenticating with the given Kubernetes settings.
func TokenKubernetesAuthentication(kubernetesConfig KubernetesConfig) *KubernetesClient {
	k8s := new(KubernetesClient)
	k8s.config = kubernetesConfig
	return k8s
}
// readTokenFromFile returns the full content of the file at filepath as a string, unmodified.
// The read error is wrapped with %w.
func readTokenFromFile(filepath string) (string, error) {
	content, readErr := os.ReadFile(filepath)
	if readErr == nil {
		return string(content), nil
	}
	return "", fmt.Errorf("unable to read file containing service account token: %w", readErr)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consul
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/hashicorp/consul/api"
"strings"
)
// logger is the package logger of the Consul connection.
var logger = log.New("Consul")
const (
// PropertyPrefix is the configuration prefix of ConnectionProperties.
PropertyPrefix = "cloud.consul"
)
var (
// ErrNoInstances indicates that no matching service instances were found.
ErrNoInstances = errors.New("no matching service instances found")
)
// Connection wraps a Consul API client together with the properties and the authentication it was created with.
type Connection struct {
client *api.Client
// properties may be nil when the connection was created without WithProperties.
properties *ConnectionProperties
clientAuth ClientAuthentication
}
// Client returns the underlying Consul API client.
func (c *Connection) Client() *api.Client {
return c.client
}
// ListKeyValuePairs lists all KV entries below "<path>/" and returns them keyed by their path relative to that
// prefix. Values are converted with utils.ParseString. The entry for path itself, and entries that do not start
// with "<path>/" (e.g. siblings such as "<path>-other"), are skipped.
func (c *Connection) ListKeyValuePairs(ctx context.Context, path string) (results map[string]interface{}, err error) {
	queryOptions := &api.QueryOptions{}
	entries, _, err := c.client.KV().List(path, queryOptions.WithContext(ctx))
	if err != nil {
		return nil, err
	}
	prefix := path + "/"
	results = make(map[string]interface{}, len(entries))
	for _, entry := range entries {
		if !strings.HasPrefix(entry.Key, prefix) {
			continue
		}
		propName := strings.TrimPrefix(entry.Key, prefix)
		if propName == "" {
			continue
		}
		results[propName] = utils.ParseString(string(entry.Value))
	}
	return results, nil
}
// GetKeyValue returns the raw value stored at path. A missing key is not an error: it yields a nil value.
func (c *Connection) GetKeyValue(ctx context.Context, path string) (value []byte, err error) {
	opts := (&api.QueryOptions{}).WithContext(ctx)
	pair, _, err := c.client.KV().Get(path, opts)
	if err != nil {
		return nil, err
	}
	if pair != nil {
		value = pair.Value
	}
	logger.WithContext(ctx).Debugf("Retrieved kv pair from consul %q: %s", c.host(), path)
	return value, nil
}
// SetKeyValue stores value at path, creating or overwriting the key.
func (c *Connection) SetKeyValue(ctx context.Context, path string, value []byte) error {
	opts := (&api.WriteOptions{}).WithContext(ctx)
	if _, err := c.client.KV().Put(&api.KVPair{Key: path, Value: value}, opts); err != nil {
		return err
	}
	logger.WithContext(ctx).Debugf("Stored kv pair to consul %q: %s", c.host(), path)
	return nil
}
// host returns "<Host>:<Port>" from the connection properties, for logging.
// It returns "unknown" when the connection was created without properties; otherwise it would panic, and it is
// evaluated eagerly in debug log calls.
func (c *Connection) host() string {
	if c.properties == nil {
		return "unknown"
	}
	return fmt.Sprintf(`%s:%d`, c.properties.Host, c.properties.Port)
}
// Options customizes a ClientConfig before the connection is created; a returned error aborts New.
type Options func(cfg *ClientConfig) error
// ClientConfig is the configuration assembled by New: the Consul API config, the connection properties
// and the authentication used to obtain the ACL token.
type ClientConfig struct {
*api.Config
Properties *ConnectionProperties
ClientAuth ClientAuthentication
}
// WithProperties applies ConnectionProperties to the client config: it sets the address, the scheme, the
// authentication derived from p.Authentication and, for the "https" scheme only, the TLS file settings and
// InsecureSkipVerify.
func WithProperties(p ConnectionProperties) Options {
	return func(cfg *ClientConfig) error {
		props := &p
		cfg.Properties = props
		cfg.ClientAuth = newClientAuthentication(props)
		cfg.Address = props.Address()
		cfg.Scheme = props.Scheme
		if cfg.Scheme != "https" {
			return nil
		}
		tlsCfg := &cfg.TLSConfig
		tlsCfg.CAFile = props.SSL.CaCert
		tlsCfg.CertFile = props.SSL.ClientCert
		tlsCfg.KeyFile = props.SSL.ClientKey
		tlsCfg.InsecureSkipVerify = props.SSL.Insecure
		return nil
	}
}
// New creates a Connection. It starts from the Consul API defaults with an empty static token, applies
// opts in order (stopping at the first error), then logs in and builds the client.
func New(opts ...Options) (*Connection, error) {
	cfg := &ClientConfig{
		Config:     api.DefaultConfig(),
		ClientAuth: TokenClientAuthentication(""),
	}
	for _, apply := range opts {
		if err := apply(cfg); err != nil {
			return nil, err
		}
	}
	return newConn(cfg)
}
// newConn builds a first client used only to log in, then stores the obtained token in cfg.Token (when
// ClientAuth is set). It then builds the final client with that config.
func newConn(cfg *ClientConfig) (*Connection, error) {
	loginClient, err := api.NewClient(cfg.Config)
	if err != nil {
		return nil, err
	}
	if auth := cfg.ClientAuth; auth != nil {
		token, loginErr := auth.Login(loginClient)
		if loginErr != nil {
			return nil, loginErr
		}
		cfg.Token = token
	}
	finalClient, err := api.NewClient(cfg.Config)
	if err != nil {
		return nil, err
	}
	conn := &Connection{
		client:     finalClient,
		properties: cfg.Properties,
		clientAuth: cfg.ClientAuth,
	}
	return conn, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consul
import "strings"
const (
// Token authenticates with the static ACL token from ConnectionProperties.Token.
Token = AuthMethod("token")
// Kubernetes authenticates by logging into Consul with a Kubernetes service account JWT.
Kubernetes = AuthMethod("kubernetes")
)
// refreshable is the set of auth methods reported as refreshable by AuthMethod.isRefreshable.
var refreshable = map[AuthMethod]struct{}{
Kubernetes: {},
}
// AuthMethod is the Consul authentication method; text values are lower-cased when unmarshalled.
type AuthMethod string
// UnmarshalText implements encoding.TextUnmarshaler. The text is lower-cased, so "Kubernetes" and
// "kubernetes" are equivalent. It never returns an error.
func (a *AuthMethod) UnmarshalText(data []byte) error {
	lowered := strings.ToLower(string(data))
	*a = AuthMethod(lowered)
	return nil
}
// isRefreshable reports whether a is one of the refreshable auth methods.
func (a AuthMethod) isRefreshable() bool {
	if _, found := refreshable[a]; found {
		return true
	}
	return false
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consulhealth
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
"github.com/cisco-open/go-lanai/pkg/consul"
"go.uber.org/fx"
)
// HealthIndicator reports Consul health based on whether the cluster leader can be queried.
type HealthIndicator struct {
conn *consul.Connection
}
// HealthRegDI holds the optional dependencies of Register; registration is skipped when either is missing.
type HealthRegDI struct {
fx.In
HealthRegistrar health.Registrar `optional:"true"`
ConsulClient *consul.Connection `optional:"true"`
}
// Register registers the Consul HealthIndicator, if both a health registrar and a Consul connection exist.
func Register(di HealthRegDI) error {
	if di.HealthRegistrar != nil && di.ConsulClient != nil {
		return di.HealthRegistrar.Register(New(di.ConsulClient))
	}
	return nil
}
// New creates a HealthIndicator checking the given Consul connection.
func New(conn *consul.Connection) *HealthIndicator {
	indicator := new(HealthIndicator)
	indicator.conn = conn
	return indicator
}
// Name returns the health component name "consul".
func (i *HealthIndicator) Name() string {
	const componentName = "consul"
	return componentName
}
// Health reports UP when the Consul leader can be queried and DOWN otherwise. The query error is not
// included in the result. Neither c nor options is used.
func (i *HealthIndicator) Health(c context.Context, options health.Options) health.Health {
	_, err := i.conn.Client().Status().Leader()
	if err != nil {
		return health.NewDetailedHealth(health.StatusDown, "consul leader status failed", nil)
	}
	return health.NewDetailedHealth(health.StatusUp, "consul leader status succeeded", nil)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consul
import (
"embed"
consulappconfig "github.com/cisco-open/go-lanai/pkg/consul/appconfig"
"github.com/cisco-open/go-lanai/pkg/appconfig"
appconfigInit "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/consul"
consulhealth "github.com/cisco-open/go-lanai/pkg/consul/health"
"github.com/pkg/errors"
"go.uber.org/fx"
)
// defaultConfigFS embeds the default Consul properties file.
//go:embed defaults-consul.yml
var defaultConfigFS embed.FS
// Module wires Consul support. With priority, it binds ConnectionProperties and provides the default
// *consul.Connection. It also contributes the embedded default properties, registers the Consul health
// indicator, and includes the Consul app config module.
var Module = &bootstrap.Module{
Name: "consul",
Precedence: bootstrap.ConsulPrecedence,
PriorityOptions: []fx.Option{
fx.Provide(BindConnectionProperties),
fx.Provide(ProvideDefaultClient),
},
Options: []fx.Option{
appconfigInit.FxEmbeddedDefaults(defaultConfigFS),
fx.Invoke(consulhealth.Register),
},
Modules: []*bootstrap.Module{
consulappconfig.Module,
},
}
// Use registers Module with the bootstrap process.
func Use() {
bootstrap.Register(Module)
}
// BindConnectionProperties binds consul.ConnectionProperties from the bootstrap configuration under
// consul.PropertyPrefix. It panics when binding fails.
func BindConnectionProperties(bootstrapConfig *appconfig.BootstrapConfig) consul.ConnectionProperties {
	var props consul.ConnectionProperties
	if err := bootstrapConfig.Bind(&props, consul.PropertyPrefix); err != nil {
		panic(errors.Wrap(err, "failed to bind consul's ConnectionProperties"))
	}
	return props
}
// clientDI holds the dependencies of ProvideDefaultClient.
type clientDI struct {
fx.In
Props consul.ConnectionProperties
// Customizers from the "consul" fx group are applied after the options derived from Props.
Customizers []consul.Options `group:"consul"`
}
// ProvideDefaultClient creates the default Consul connection. It applies the bound properties first, then
// every customizer in the order they were collected.
func ProvideDefaultClient(di clientDI) (*consul.Connection, error) {
	opts := make([]consul.Options, 0, len(di.Customizers)+1)
	opts = append(opts, consul.WithProperties(di.Props))
	opts = append(opts, di.Customizers...)
	return consul.New(opts...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consul
import "fmt"
// ConnectionProperties configures the Consul connection (bound from "cloud.consul").
type ConnectionProperties struct {
Host string `json:"host"`
Port int `json:"port"`
// Scheme is the URL scheme; TLS settings in SSL are only applied for "https".
Scheme string `json:"scheme"`
SSL SSLProperties `json:"ssl"`
// Authentication selects how the ACL token is obtained; unknown values fall back to Token.
Authentication AuthMethod `json:"authentication"`
Kubernetes KubernetesConfig `json:"kubernetes"`
// Token is the static ACL token used with Token authentication.
Token string `json:"token"`
}
// Address returns "<Host>:<Port>". Host is used verbatim, so IPv6 literals must already be bracketed.
func (c ConnectionProperties) Address() string {
	host, port := c.Host, c.Port
	return fmt.Sprintf("%s:%d", host, port)
}
// SSLProperties configures TLS for the Consul client: CA file, client certificate/key files and whether
// server verification is skipped (Insecure).
type SSLProperties struct {
CaCert string `json:"ca-cert"`
ClientCert string `json:"client-cert"`
ClientKey string `json:"client-key"`
Insecure bool `json:"insecure"`
}
// KubernetesConfig configures Kubernetes auth. JWTPath is the service account token file; when empty, Login
// uses the in-cluster default. Method is the Consul ACL auth method name.
type KubernetesConfig struct {
JWTPath string `json:"jwt-path"`
Method string `json:"method"`
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package data
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"gorm.io/gorm"
"sort"
)
const (
// gormPluginErrorTranslation is the name of the error translation plugin and of the callbacks it registers.
gormPluginErrorTranslation = gormCallbackPrefix + "error:translate"
)
// errorHandlingGormConfigurer implements a GormConfigurer that installs errorTranslatorGormPlugin for
// error handling/transformation.
// see errorTranslatorGormPlugin for more details
type errorHandlingGormConfigurer []ErrorTranslator
// NewGormErrorHandlingConfigurer returns a GormConfigurer installing the error translation plugin with the
// given translators.
func NewGormErrorHandlingConfigurer(translators ...ErrorTranslator) GormConfigurer {
	configurer := errorHandlingGormConfigurer(translators)
	return configurer
}
// Order returns this configurer's ordering value (0).
func (c errorHandlingGormConfigurer) Order() int {
	const configurerOrder = 0
	return configurerOrder
}
// Configure registers the error translation plugin in config.Plugins, creating the map when needed.
// Any plugin previously registered under the same name is replaced.
func (c errorHandlingGormConfigurer) Configure(config *gorm.Config) {
	plugins := config.Plugins
	if plugins == nil {
		plugins = make(map[string]gorm.Plugin)
		config.Plugins = plugins
	}
	plugins[gormPluginErrorTranslation] = newErrorHandlingGormPlugin(c...)
}
// errorTranslatorGormPlugin installs gorm callbacks for the Create, Query, Update, Delete, Raw and Row operations.
// The callbacks give each ErrorTranslator, in order, a chance to translate db.Error before the *gorm.DB
// operation returns. Translation stops as soon as db.Error becomes nil.
type errorTranslatorGormPlugin []ErrorTranslator
func newErrorHandlingGormPlugin(translators ...ErrorTranslator) errorTranslatorGormPlugin {
sort.SliceStable(translators, func(i, j int) bool {
return order.OrderedFirstCompare(translators[i], translators[j])
})
return translators
}
// Name implements gorm.Plugin and returns gormPluginErrorTranslation.
func (errorTranslatorGormPlugin) Name() string {
	return gormPluginErrorTranslation
}

// Initialize implements gorm.Plugin. It registers the error translation callback after all ("*") other
// callbacks of the Create, Query, Update, Delete, Raw and Row processors.
//
// Every registration is attempted. If any of them fails, the first failure in the order listed above is
// returned, wrapped with %w so callers can inspect the underlying gorm error.
func (p errorTranslatorGormPlugin) Initialize(db *gorm.DB) error {
	cbName := gormPluginErrorTranslation
	cbs := db.Callback()
	// An ordered slice (instead of a map) keeps the reported error deterministic when several registrations fail.
	// Composite-literal elements are evaluated left to right, so registration order is preserved.
	results := []struct {
		op  string
		err error
	}{
		{"Create", cbs.Create().After("*").Register(cbName, p.translateErrorCallback())},
		{"Query", cbs.Query().After("*").Register(cbName, p.translateErrorCallback())},
		{"Update", cbs.Update().After("*").Register(cbName, p.translateErrorCallback())},
		{"Delete", cbs.Delete().After("*").Register(cbName, p.translateErrorCallback())},
		{"Raw", cbs.Raw().After("*").Register(cbName, p.translateErrorCallback())},
		{"Row", cbs.Row().After("*").Register(cbName, p.translateErrorCallback())},
	}
	for _, r := range results {
		if r.err != nil {
			return fmt.Errorf("unable to install error transformation callbacks for %s: %w", r.op, r.err)
		}
	}
	return nil
}
// translateErrorCallback returns a gorm callback that passes db.Error through the plugin's translators in
// order. A translator that also implements GormErrorTranslator receives the whole *gorm.DB; the others
// receive the statement context and the current error. Translation stops as soon as db.Error becomes nil.
func (p errorTranslatorGormPlugin) translateErrorCallback() func(*gorm.DB) {
	return func(db *gorm.DB) {
		for idx := 0; idx < len(p) && db.Error != nil; idx++ {
			if withDB, ok := p[idx].(GormErrorTranslator); ok {
				db.Error = withDB.TranslateWithDB(db)
			} else {
				db.Error = p[idx].Translate(db.Statement.Context, db.Error)
			}
		}
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package data
import (
"context"
"errors"
errorutils "github.com/cisco-open/go-lanai/pkg/utils/error"
"net/http"
)
// WebDataErrorTranslator implements web.ErrorTranslator
//
//goland:noinspection GoNameStartsWithPackageName
type WebDataErrorTranslator struct{}

// NewWebDataErrorTranslator returns a zero-value WebDataErrorTranslator as an ErrorTranslator.
//
//goland:noinspection GoNameStartsWithPackageName
func NewWebDataErrorTranslator() ErrorTranslator {
	var translator WebDataErrorTranslator
	return translator
}

// Order reports ErrorTranslatorOrderData as this translator's ordering value.
func (WebDataErrorTranslator) Order() int {
	priority := ErrorTranslatorOrderData
	return priority
}
// Translate attaches an HTTP status code to data-category errors so the web layer can render them.
// Errors that are not errorutils.ErrorCoder, or are not in ErrorCategoryData, are returned unchanged.
//
// Mapping (first match wins):
//   - ErrorRecordNotFound, ErrorIncorrectRecordCount -> 404 Not Found
//   - ErrorSubTypeDataIntegrity                      -> 409 Conflict
//   - ErrorSubTypeQuery                              -> 400 Bad Request
//   - ErrorSubTypeTimeout                            -> 408 Request Timeout
//   - ErrorTypeTransient                             -> 503 Service Unavailable
//   - anything else in the data category             -> 500 Internal Server Error
func (t WebDataErrorTranslator) Translate(ctx context.Context, err error) error {
	//nolint:errorlint
	if _, ok := err.(errorutils.ErrorCoder); !ok || !errors.Is(err, ErrorCategoryData) {
		return err
	}
	switch {
	case errors.Is(err, ErrorRecordNotFound), errors.Is(err, ErrorIncorrectRecordCount):
		return t.errorWithStatusCode(ctx, err, http.StatusNotFound)
	case errors.Is(err, ErrorSubTypeDataIntegrity):
		return t.errorWithStatusCode(ctx, err, http.StatusConflict)
	case errors.Is(err, ErrorSubTypeQuery):
		return t.errorWithStatusCode(ctx, err, http.StatusBadRequest)
	case errors.Is(err, ErrorSubTypeTimeout):
		return t.errorWithStatusCode(ctx, err, http.StatusRequestTimeout)
	case errors.Is(err, ErrorTypeTransient):
		return t.errorWithStatusCode(ctx, err, http.StatusServiceUnavailable)
	default:
		return t.errorWithStatusCode(ctx, err, http.StatusInternalServerError)
	}
}

// errorWithStatusCode wraps err into a DataError carrying the HTTP status code sc.
// err is deliberately not type-asserted to DataError: data-category ErrorCoders such as *CodedError are
// not DataError and the assertion would panic. NewErrorWithStatusCode already handles every error shape.
func (t WebDataErrorTranslator) errorWithStatusCode(_ context.Context, err error, sc int) error {
	return NewErrorWithStatusCode(err, sc)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package data
import (
"errors"
. "github.com/cisco-open/go-lanai/pkg/utils/error"
"gorm.io/gorm"
)
const (
// Reserved is the error-code range reserved for the data package: 0xdb shifted left by ReservedOffset.
// All type, sub-type and concrete codes below are built on top of it.
Reserved = 0xdb << ReservedOffset
)
// Error "Type" codes. Each value is Reserved + n<<ErrorTypeOffset (n = 1, 2, 3, ... via iota; 0 is skipped).
// All "Type" values are used as mask
const (
_ = iota
ErrorTypeCodeInternal = Reserved + iota<<ErrorTypeOffset
ErrorTypeCodeNonTransient
ErrorTypeCodeTransient
ErrorTypeCodeUncategorizedServerSide
)
// All "SubType" values are used as mask.
// Sub types of ErrorTypeCodeInternal: ErrorTypeCodeInternal + n<<ErrorSubTypeOffset (n starts at 1).
const (
_ = iota
ErrorSubTypeCodeInternal = ErrorTypeCodeInternal + iota<<ErrorSubTypeOffset
)
// All "SubType" values are used as mask.
// Sub types of ErrorTypeCodeNonTransient: ErrorTypeCodeNonTransient + n<<ErrorSubTypeOffset (n starts at 1).
const (
_ = iota
ErrorSubTypeCodeQuery = ErrorTypeCodeNonTransient + iota<<ErrorSubTypeOffset
ErrorSubTypeCodeApi
ErrorSubTypeCodeDataRetrieval
ErrorSubTypeCodeDataIntegrity
ErrorSubTypeCodeTransaction
ErrorSubTypeCodeSecurity
)
// All "SubType" values are used as mask.
// Sub types of ErrorTypeCodeTransient: ErrorTypeCodeTransient + n<<ErrorSubTypeOffset (n starts at 1).
const (
_ = iota
ErrorSubTypeCodeConcurrency = ErrorTypeCodeTransient + iota<<ErrorSubTypeOffset
ErrorSubTypeCodeTimeout
ErrorSubTypeCodeReplica
)
// Concrete error codes. Each group below is its sub-type code plus a small ordinal (1, 2, 3, ... via iota).
// Concrete codes of ErrorSubTypeCodeInternal
const (
_ = iota
ErrorCodeInternal = ErrorSubTypeCodeInternal + iota
)
// Concrete codes of ErrorSubTypeCodeQuery
const (
_ = iota
ErrorCodeInvalidSQL = ErrorSubTypeCodeQuery + iota
ErrorCodeInvalidPagination
ErrorCodeInsufficientPrivilege
)
// Concrete codes of ErrorSubTypeCodeApi
const (
_ = iota
ErrorCodeInvalidApiUsage = ErrorSubTypeCodeApi + iota
ErrorCodeUnsupportedCondition
ErrorCodeUnsupportedOptions
ErrorCodeInvalidCrudModel
ErrorCodeInvalidCrudParam
)
// Concrete codes of ErrorSubTypeCodeDataRetrieval
const (
_ = iota
ErrorCodeRecordNotFound = ErrorSubTypeCodeDataRetrieval + iota
ErrorCodeOrmMapping
ErrorCodeIncorrectRecordCount
)
// Concrete codes of ErrorSubTypeCodeDataIntegrity
const (
_ = iota
ErrorCodeDuplicateKey = ErrorSubTypeCodeDataIntegrity + iota
ErrorCodeConstraintViolation
ErrorCodeInvalidSchema
)
// Concrete codes of ErrorSubTypeCodeTransaction
const (
_ = iota
ErrorCodeInvalidTransaction = ErrorSubTypeCodeTransaction + iota
)
// Concrete codes of ErrorSubTypeCodeSecurity
const (
_ = iota
ErrorCodeAuthenticationFailed = ErrorSubTypeCodeSecurity + iota
ErrorCodeFieldOperationDenied
)
// Concrete codes of ErrorSubTypeCodeConcurrency
const (
_ = iota
ErrorCodePessimisticLocking = ErrorSubTypeCodeConcurrency + iota
ErrorCodeOptimisticLocking
)
// Concrete codes of ErrorSubTypeCodeTimeout
const (
_ = iota
ErrorCodeQueryTimeout = ErrorSubTypeCodeTimeout + iota
)
// Concrete codes of ErrorSubTypeCodeReplica
const (
_ = iota
ErrorCodeReplicaUnavailable = ErrorSubTypeCodeReplica + iota
)
// Category, type and sub-type sentinels built from the codes above.
// They can be used as targets in errors.Is to match any error within that category/type/sub-type.
// NOTE(review): the ErrorSubTypeConcurrency message misspells "concurrency"; it is runtime output and left unchanged.
var (
ErrorCategoryData = NewErrorCategory(Reserved, errors.New("error type: data"))
ErrorTypeInternal = NewErrorType(ErrorTypeCodeInternal, errors.New("error type: internal"))
ErrorTypeNonTransient = NewErrorType(ErrorTypeCodeNonTransient, errors.New("error type: non-transient"))
ErrorTypeTransient = NewErrorType(ErrorTypeCodeTransient, errors.New("error type: transient"))
ErrorTypeUnCategorizedServerSide = NewErrorType(ErrorTypeCodeUncategorizedServerSide, errors.New("error type: uncategorized server-side"))
ErrorSubTypeInternalError = NewErrorSubType(ErrorSubTypeCodeInternal, errors.New("error sub-type: internal"))
ErrorSubTypeQuery = NewErrorSubType(ErrorSubTypeCodeQuery, errors.New("error sub-type: query"))
ErrorSubTypeApi = NewErrorSubType(ErrorSubTypeCodeApi, errors.New("error sub-type: api"))
ErrorSubTypeDataRetrieval = NewErrorSubType(ErrorSubTypeCodeDataRetrieval, errors.New("error sub-type: retrieval"))
ErrorSubTypeDataIntegrity = NewErrorSubType(ErrorSubTypeCodeDataIntegrity, errors.New("error sub-type: integrity"))
ErrorSubTypeTransaction = NewErrorSubType(ErrorSubTypeCodeTransaction, errors.New("error sub-type: transaction"))
ErrorSubTypeSecurity = NewErrorSubType(ErrorSubTypeCodeSecurity, errors.New("error sub-type: security"))
ErrorSubTypeConcurrency = NewErrorSubType(ErrorSubTypeCodeConcurrency, errors.New("error sub-type: concurency"))
ErrorSubTypeTimeout = NewErrorSubType(ErrorSubTypeCodeTimeout, errors.New("error sub-type: timeout"))
ErrorSubTypeReplica = NewErrorSubType(ErrorSubTypeCodeReplica, errors.New("error sub-type: replica"))
)
// Concrete errors, can be used in errors.Is for an exact match.
// Each is a DataError (created by NewDataError) carrying one of the concrete codes above.
var (
ErrorSortByUnknownColumn = NewDataError(ErrorCodeOrmMapping, "SortBy column unknown")
ErrorRecordNotFound = NewDataError(ErrorCodeRecordNotFound, gorm.ErrRecordNotFound)
ErrorIncorrectRecordCount = NewDataError(ErrorCodeIncorrectRecordCount, "incorrect record count")
ErrorDuplicateKey = NewDataError(ErrorCodeDuplicateKey, "duplicate key")
ErrorInsufficientPrivilege = NewDataError(ErrorCodeInsufficientPrivilege, "insufficient privilege")
)
// init reserves the data error category (ErrorCategoryData) with the error utilities' Reserve function.
func init() {
Reserve(ErrorCategoryData)
}
// DataError is the error type produced by the data package: an error and NestedError that can also carry
// arbitrary details and be copied with a new message, cause or details.
//
//goland:noinspection GoNameStartsWithPackageName
type DataError interface {
error
NestedError
// Details returns the attached details, if any.
Details() interface{}
// WithDetails returns a DataError carrying the given details.
WithDetails(interface{}) DataError
// WithMessage returns a DataError with a formatted message.
WithMessage(msg string, args ...interface{}) DataError
// WithCause returns a DataError with the given cause and formatted message.
WithCause(cause error, msg string, args ...interface{}) DataError
}
// dataError implements DataError and errorutils.Unwrapper
// It embeds *CodedError (code, message, cause) and adds an opaque details value.
//
//goland:noinspection GoNameStartsWithPackageName
type dataError struct {
*CodedError
details interface{}
}
// Details returns the value attached via WithDetails (nil when none was attached).
func (e dataError) Details() interface{} {
	details := e.details
	return details
}

// WithDetails returns a copy of e with its details replaced; the underlying *CodedError is shared.
func (e dataError) WithDetails(details interface{}) DataError {
	updated := e
	updated.details = details
	return updated
}

// WithMessage returns a copy of e whose CodedError carries the formatted message; details are kept.
func (e dataError) WithMessage(msg string, args ...interface{}) DataError {
	// e is a value receiver, so assigning its field only affects the copy being returned.
	e.CodedError = e.CodedError.WithMessage(msg, args...)
	return e
}

// WithCause returns a copy of e whose CodedError carries the given cause and message; details are kept.
func (e dataError) WithCause(cause error, msg string, args ...interface{}) DataError {
	e.CodedError = e.CodedError.WithCause(cause, msg, args...)
	return e
}

// Unwrap returns the root cause when the direct cause is itself a NestedError, otherwise the direct cause.
func (e dataError) Unwrap() error {
	direct := e.Cause()
	//nolint:errorlint
	if _, nested := direct.(NestedError); nested {
		return e.RootCause()
	}
	return direct
}
// webDataError is a DataError that also implements web.StatusCoder by carrying an HTTP status code.
//
//goland:noinspection GoNameStartsWithPackageName
type webDataError struct {
	dataError
	SC int
}

// StatusCode returns the HTTP status code carried by this error.
func (e webDataError) StatusCode() int {
	return e.SC
}

// WithStatusCode returns a copy of e with the status code replaced by sc.
func (e webDataError) WithStatusCode(sc int) DataError {
	return webDataError{dataError: e.dataError, SC: sc}
}

// WithMessage returns a copy with a new formatted message, keeping details and status code.
func (e webDataError) WithMessage(msg string, args ...interface{}) DataError {
	return webDataError{dataError: e.dataError.WithMessage(msg, args...).(dataError), SC: e.SC}
}

// WithDetails returns a copy with the given details, keeping the status code.
// Without this override the promoted dataError.WithDetails would return a plain dataError and drop SC.
func (e webDataError) WithDetails(details interface{}) DataError {
	return webDataError{dataError: e.dataError.WithDetails(details).(dataError), SC: e.SC}
}

// WithCause returns a copy with the given cause and formatted message, keeping details and status code.
// Without this override the promoted dataError.WithCause would return a plain dataError and drop SC.
func (e webDataError) WithCause(cause error, msg string, args ...interface{}) DataError {
	return webDataError{dataError: e.dataError.WithCause(cause, msg, args...).(dataError), SC: e.SC}
}
/**********************
Constructors
**********************/
// NewDataError creates a DataError with the given code. e and causes are forwarded to NewCodedError.
// The concrete value returned is a *dataError without details.
func NewDataError(code int64, e interface{}, causes ...interface{}) DataError {
	coded := NewCodedError(code, e, causes...)
	return &dataError{CodedError: coded}
}
// NewErrorWithStatusCode returns a DataError (a *webDataError) carrying HTTP status code sc.
//
//   - dataError / *dataError (what NewDataError returns): its CodedError and details are kept as-is.
//   - webDataError / *webDataError: only the status code is replaced.
//   - CodedError / *CodedError: used directly as the underlying CodedError.
//   - any other ErrorCoder: wrapped in a new data error with the same code.
//   - anything else: wrapped in a new data error with ErrorSubTypeCodeInternal.
//
// Wrapped errors are not unwrapped; only the dynamic type of err itself is inspected.
func NewErrorWithStatusCode(err error, sc int) DataError {
	//nolint:errorlint // we don't consider wrapped error here.
	switch e := err.(type) {
	case dataError:
		return &webDataError{dataError: e, SC: sc}
	case *dataError:
		// NewDataError returns *dataError; keep its CodedError and details instead of re-wrapping it.
		return &webDataError{dataError: *e, SC: sc}
	case webDataError:
		return &webDataError{dataError: e.dataError, SC: sc}
	case *webDataError:
		return &webDataError{dataError: e.dataError, SC: sc}
	case CodedError:
		return &webDataError{dataError: dataError{CodedError: &e}, SC: sc}
	case *CodedError:
		return &webDataError{dataError: dataError{CodedError: e}, SC: sc}
	case ErrorCoder:
		return &webDataError{dataError: *NewDataError(e.Code(), e).(*dataError), SC: sc}
	default:
		return &webDataError{dataError: *NewDataError(ErrorSubTypeCodeInternal, e).(*dataError), SC: sc}
	}
}
// NewInternalError creates a DataError with sub-type code ErrorSubTypeCodeInternal.
func NewInternalError(value interface{}, causes ...interface{}) DataError {
	err := NewDataError(ErrorSubTypeCodeInternal, value, causes...)
	return err
}

// NewRecordNotFoundError creates a DataError with code ErrorCodeRecordNotFound.
func NewRecordNotFoundError(value interface{}, causes ...interface{}) DataError {
	err := NewDataError(ErrorCodeRecordNotFound, value, causes...)
	return err
}

// NewConstraintViolationError creates a DataError with code ErrorCodeConstraintViolation.
func NewConstraintViolationError(value interface{}, causes ...interface{}) DataError {
	err := NewDataError(ErrorCodeConstraintViolation, value, causes...)
	return err
}

// NewDuplicateKeyError creates a DataError with code ErrorCodeDuplicateKey.
func NewDuplicateKeyError(value interface{}, causes ...interface{}) DataError {
	err := NewDataError(ErrorCodeDuplicateKey, value, causes...)
	return err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package data
import (
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
"time"
)
// Names of gorm's default callbacks, used as anchors with Before()/After() when registering the
// tracing callbacks. Query, Row and Raw have a single main callback, so their "Before" and "After"
// anchors share the same name.
const (
GormCallbackBeforeCreate = "gorm:before_create"
GormCallbackAfterCreate = "gorm:after_create"
GormCallbackBeforeQuery = "gorm:query"
GormCallbackAfterQuery = "gorm:after_query"
GormCallbackBeforeUpdate = "gorm:before_update"
GormCallbackAfterUpdate = "gorm:after_update"
GormCallbackBeforeDelete = "gorm:before_delete"
GormCallbackAfterDelete = "gorm:after_delete"
GormCallbackBeforeRow = "gorm:row"
GormCallbackAfterRow = "gorm:row"
GormCallbackBeforeRaw = "gorm:raw"
GormCallbackAfterRaw = "gorm:raw"
)
// gormCallbackPrefix prefixes names of plugins and callbacks registered by this package.
const (
gormCallbackPrefix = "lanai:"
)
// GormErrorTranslator translates db.Error with access to the whole *gorm.DB (e.g. its Statement).
type GormErrorTranslator interface {
TranslateWithDB(db *gorm.DB) error
}
// GormConfigurer customizes the gorm.Config before the connection is opened by NewGorm.
type GormConfigurer interface {
Configure(config *gorm.Config)
}
// GormOptions mutates a GormConfig; NewGorm applies them in order.
type GormOptions func(cfg *GormConfig)
// GormConfig collects the inputs of NewGorm.
type GormConfig struct {
// Dialector is passed to gorm.Open.
Dialector gorm.Dialector
// LogLevel is mapped onto a gorm logger level by NewGorm.
LogLevel log.LoggingLevel
// LogSlowQueryThreshold marks queries slower than this as "slow" (0 disables slow-query logging).
LogSlowQueryThreshold time.Duration
// Configurers are sorted by order.OrderedFirstCompare and applied to the gorm.Config.
Configurers []GormConfigurer
}
// NewGorm opens a *gorm.DB from the GormConfig produced by applying opts in order.
//
// Defaults: LogSlowQueryThreshold is 15s. LogLevel maps to gorm levels as
// Off -> Silent, Debug/Info -> Info, Error -> Error, anything else (including Warn) -> Warn.
// Configurers are stably sorted by order.OrderedFirstCompare before being applied.
// It panics if gorm.Open fails.
func NewGorm(opts ...GormOptions) *gorm.DB {
	cfg := GormConfig{LogSlowQueryThreshold: 15 * time.Second}
	for _, apply := range opts {
		apply(&cfg)
	}

	var level gormlogger.LogLevel
	switch cfg.LogLevel {
	case log.LevelOff:
		level = gormlogger.Silent
	case log.LevelDebug, log.LevelInfo:
		level = gormlogger.Info
	case log.LevelError:
		level = gormlogger.Error
	default: // includes log.LevelWarn
		level = gormlogger.Warn
	}

	gormCfg := &gorm.Config{
		Logger: newGormLogger(level, cfg.LogSlowQueryThreshold),
	}

	// give every configurer a chance to customize the gorm config, in order
	order.SortStable(cfg.Configurers, order.OrderedFirstCompare)
	for _, configurer := range cfg.Configurers {
		configurer.Configure(gormCfg)
	}

	db, err := gorm.Open(cfg.Dialector, gormCfg)
	if err != nil {
		panic(err)
	}
	return db
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package data
import (
"context"
"errors"
"gorm.io/gorm"
)
// GormErrorMapping maps gorm sentinel errors to DataErrors with a matching data error code.
// It is consulted with errors.Is by gormErrorTranslator.Translate; map iteration order is unspecified.
var (
GormErrorMapping = map[error]DataError{
gorm.ErrRecordNotFound: NewDataError(ErrorCodeRecordNotFound, gorm.ErrRecordNotFound),
gorm.ErrInvalidTransaction: NewDataError(ErrorCodeInvalidTransaction, gorm.ErrInvalidTransaction),
gorm.ErrNotImplemented: NewDataError(ErrorCodeInvalidApiUsage, gorm.ErrNotImplemented),
gorm.ErrMissingWhereClause: NewDataError(ErrorCodeInvalidSQL, gorm.ErrMissingWhereClause),
gorm.ErrUnsupportedRelation: NewDataError(ErrorCodeInvalidSchema, gorm.ErrUnsupportedRelation),
gorm.ErrPrimaryKeyRequired: NewDataError(ErrorCodeInvalidSQL, gorm.ErrPrimaryKeyRequired),
gorm.ErrModelValueRequired: NewDataError(ErrorCodeOrmMapping, gorm.ErrModelValueRequired),
gorm.ErrInvalidData: NewDataError(ErrorCodeOrmMapping, gorm.ErrInvalidData),
gorm.ErrUnsupportedDriver: NewDataError(ErrorCodeInternal, gorm.ErrUnsupportedDriver),
gorm.ErrRegistered: NewDataError(ErrorCodeInternal, gorm.ErrRegistered), // TODO ??
gorm.ErrInvalidField: NewDataError(ErrorCodeInvalidSQL, gorm.ErrInvalidField),
gorm.ErrEmptySlice: NewDataError(ErrorCodeIncorrectRecordCount, gorm.ErrEmptySlice),
gorm.ErrDryRunModeUnsupported: NewDataError(ErrorCodeInvalidApiUsage, gorm.ErrDryRunModeUnsupported),
gorm.ErrInvalidDB: NewDataError(ErrorCodeInvalidApiUsage, gorm.ErrInvalidDB),
gorm.ErrInvalidValue: NewDataError(ErrorCodeInvalidSQL, gorm.ErrInvalidValue),
gorm.ErrInvalidValueOfLength: NewDataError(ErrorCodeInvalidSQL, gorm.ErrInvalidValueOfLength),
}
)
// DefaultGormErrorTranslator adapts an ErrorTranslator into a GormErrorTranslator.
type DefaultGormErrorTranslator struct {
	ErrorTranslator
}

// TranslateWithDB translates db.Error with the embedded ErrorTranslator. When the result is a DataError and
// the *gorm.DB has a Statement, the Statement is attached as the error's details. Returns nil when db.Error is nil.
//
// db.Statement is nil-checked before its Context is read, so a DB without a Statement does not panic;
// in that case (or when the Statement has no Context) context.Background() is used.
func (t DefaultGormErrorTranslator) TranslateWithDB(db *gorm.DB) error {
	if db.Error == nil {
		return nil
	}
	ctx := context.Background()
	if db.Statement != nil && db.Statement.Context != nil {
		ctx = db.Statement.Context
	}
	err := t.Translate(ctx, db.Error)
	//nolint:errorlint // only the translated error itself is expected to be a DataError
	if de, ok := err.(DataError); ok && db.Statement != nil {
		return de.WithDetails(db.Statement)
	}
	return err
}
// gormErrorTranslator is an ErrorTranslator that maps gorm sentinel errors via GormErrorMapping.
// NewGormErrorTranslator wraps it in DefaultGormErrorTranslator, which adds the GormErrorTranslator behavior.
type gormErrorTranslator struct{}

// NewGormErrorTranslator returns a DefaultGormErrorTranslator backed by gormErrorTranslator.
func NewGormErrorTranslator() ErrorTranslator {
	inner := gormErrorTranslator{}
	return DefaultGormErrorTranslator{ErrorTranslator: inner}
}

// Order reports ErrorTranslatorOrderGorm as this translator's ordering value.
func (gormErrorTranslator) Order() int {
	return ErrorTranslatorOrderGorm
}

// Translate returns the GormErrorMapping entry whose key matches err via errors.Is, or err unchanged
// when none matches. The map is ranged in unspecified order.
func (gormErrorTranslator) Translate(_ context.Context, err error) error {
	for sentinel, translated := range GormErrorMapping {
		if !errors.Is(err, sentinel) {
			continue
		}
		return translated
	}
	return err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package data
import (
"context"
"github.com/cisco-open/go-lanai/pkg/log"
gormlogger "gorm.io/gorm/logger"
"time"
)
// logKeyDb is the key under which a dbLogEntry is attached to log records.
const (
logKeyDb = "db"
)
// dbLogEntry is the structured key-value payload logged for each traced SQL statement.
type dbLogEntry struct {
// Type is one of "error", "slow" or "sql" (as set by GormLogger.Trace).
Type string `json:"type"`
// TimeElapsed is the statement duration, truncated to microseconds by GormLogger.Trace.
TimeElapsed time.Duration `json:"duration"`
// Error is the error message, set only for "error" entries.
Error string `json:"error"`
// Rows is the row count reported by gorm for the statement.
Rows int `json:"rows"`
// Query is the SQL text reported by gorm.
Query string `json:"query"`
}
// GormLogger implements gormlogger.Interface on top of the package logger.
type GormLogger struct {
	level         gormlogger.LogLevel
	slowThreshold time.Duration
	colored       bool
}

// newGormLogger creates a GormLogger; output coloring is enabled when the package logger writes to a terminal.
func newGormLogger(level gormlogger.LogLevel, slowThreshold time.Duration) *GormLogger {
	colored := log.IsTerminal(logger)
	return &GormLogger{level: level, slowThreshold: slowThreshold, colored: colored}
}

// LogMode returns a copy of l that uses the given level.
func (l GormLogger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
	clone := l
	clone.level = level
	return &clone
}

// Info logs a formatted message at Info level when l's level is Info or higher.
func (l GormLogger) Info(ctx context.Context, s string, i ...interface{}) {
	if l.level < gormlogger.Info {
		return
	}
	logger.WithContext(ctx).Infof(s, i...)
}

// Warn logs a formatted message at Warn level when l's level is Warn or higher.
func (l GormLogger) Warn(ctx context.Context, s string, i ...interface{}) {
	if l.level < gormlogger.Warn {
		return
	}
	logger.WithContext(ctx).Warnf(s, i...)
}

// Error logs a formatted message at Error level when l's level is Error or higher.
func (l GormLogger) Error(ctx context.Context, s string, i ...interface{}) {
	if l.level < gormlogger.Error {
		return
	}
	logger.WithContext(ctx).Errorf(s, i...)
}
// Trace implements gormlogger.Interface. It emits a Debug-level record with a dbLogEntry for a finished
// statement when the first of these applies:
//   - err != nil and the level is at least Error                   -> "DB Error"
//   - a non-zero slowThreshold is exceeded, level at least Warn    -> "DB Slow"
//   - the level is exactly Info                                    -> "DB SQL"
//
// Nothing is logged at Silent level or when no case applies; fc is only invoked when a record is emitted.
func (l GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.level <= gormlogger.Silent {
		return
	}
	elapsed := time.Since(begin)
	isSlow := l.slowThreshold != 0 && elapsed > l.slowThreshold

	var kind, label string
	reportErr := false
	switch {
	case err != nil && l.level >= gormlogger.Error:
		kind, label, reportErr = "error", "Error", true
	case isSlow && l.level >= gormlogger.Warn:
		kind, label = "slow", "Slow"
	case l.level == gormlogger.Info:
		kind, label = "sql", "SQL"
	default:
		return
	}

	query, rowCount := fc()
	entry := &dbLogEntry{
		Type:        kind,
		TimeElapsed: elapsed.Truncate(time.Microsecond),
		Rows:        int(rowCount),
		Query:       query,
	}
	if reportErr {
		entry.Error = err.Error()
	}

	title := "DB " + label
	if l.colored {
		title = gormlogger.Cyan + title + gormlogger.Reset
	}
	logger.WithContext(ctx).WithKV(logKeyDb, entry).
		Debugf("[%s] %10v | %d Rows | %s | %s", title, entry.TimeElapsed, entry.Rows, entry.Error, entry.Query)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package data
import (
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"gorm.io/gorm"
)
const (
	// gormPluginTracing is the key under which the tracing plugin is stored in gorm.Config.Plugins.
	gormPluginTracing = gormCallbackPrefix + "tracing"
	// tracingOpName prefixes every span operation name, e.g. "db select".
	tracingOpName = "db"
)

// gormConfigurer is a GormConfigurer that installs the tracing gormPlugin with the given tracer.
type gormConfigurer struct {
	tracer opentracing.Tracer
}

// NewGormTracingConfigurer returns a GormConfigurer that installs the tracing plugin using tracer.
func NewGormTracingConfigurer(tracer opentracing.Tracer) GormConfigurer {
	c := new(gormConfigurer)
	c.tracer = tracer
	return c
}

// Order returns order.Highest + 1, placing this configurer right after the highest-priority ones.
func (c gormConfigurer) Order() int {
	return order.Highest + 1
}

// Configure stores the tracing plugin in config.Plugins, creating the map when it is nil.
func (c gormConfigurer) Configure(config *gorm.Config) {
	plugin := &gormPlugin{tracer: c.tracer}
	if config.Plugins == nil {
		config.Plugins = make(map[string]gorm.Plugin)
	}
	config.Plugins[gormPluginTracing] = plugin
}

// gormCallbackFunc is the callback signature registered with gorm processors.
type gormCallbackFunc func(*gorm.DB)

// gormPlugin is a gorm.Plugin that opens a span before and finishes it after each gorm operation.
type gormPlugin struct {
	tracer opentracing.Tracer
}

// Name implements gorm.Plugin
func (p gormPlugin) Name() string {
	const pluginName = "tracing"
	return pluginName
}
// Initialize implements gorm.Plugin. It registers a "before" and an "after" tracing callback around the default
// gorm callback of the Create, Query, Update, Delete, Row and Raw processors.
// Default callbacks can be found at github.com/go-gorm/gorm/callbacks/callbacks.go
// Registration errors are discarded and nil is always returned.
func (p gormPlugin) Initialize(db *gorm.DB) error {
	cbs := db.Callback()

	_ = cbs.Create().Before(GormCallbackBeforeCreate).Register(p.cbBeforeName("create"), p.makeBeforeCallback("create"))
	_ = cbs.Create().After(GormCallbackAfterCreate).Register(p.cbAfterName("create"), p.makeAfterCallback("create"))

	_ = cbs.Query().Before(GormCallbackBeforeQuery).Register(p.cbBeforeName("query"), p.makeBeforeCallback("select"))
	_ = cbs.Query().After(GormCallbackAfterQuery).Register(p.cbAfterName("query"), p.makeAfterCallback("select"))

	_ = cbs.Update().Before(GormCallbackBeforeUpdate).Register(p.cbBeforeName("update"), p.makeBeforeCallback("update"))
	_ = cbs.Update().After(GormCallbackAfterUpdate).Register(p.cbAfterName("update"), p.makeAfterCallback("update"))

	_ = cbs.Delete().Before(GormCallbackBeforeDelete).Register(p.cbBeforeName("delete"), p.makeBeforeCallback("delete"))
	_ = cbs.Delete().After(GormCallbackAfterDelete).Register(p.cbAfterName("delete"), p.makeAfterCallback("delete"))

	_ = cbs.Row().Before(GormCallbackBeforeRow).Register(p.cbBeforeName("row"), p.makeBeforeCallback("row"))
	_ = cbs.Row().After(GormCallbackAfterRow).Register(p.cbAfterName("row"), p.makeAfterCallback("row"))

	_ = cbs.Raw().Before(GormCallbackBeforeRaw).Register(p.cbBeforeName("raw"), p.makeBeforeCallback("sql"))
	_ = cbs.Raw().After(GormCallbackAfterRaw).Register(p.cbAfterName("raw"), p.makeAfterCallback("sql"))

	return nil
}
// makeBeforeCallback returns a callback that starts a client-kind span named "db <opName>", tagged with the
// statement's table (or its TableExpr SQL when set), and stores the resulting context back on db.Statement.
func (p gormPlugin) makeBeforeCallback(opName string) gormCallbackFunc {
	spanName := tracingOpName + " " + opName
	return func(db *gorm.DB) {
		stmt := db.Statement
		table := stmt.Table
		if expr := stmt.TableExpr; expr != nil {
			table = expr.SQL
		}
		parent := stmt.Context
		stmt.Context = tracing.WithTracer(p.tracer).
			WithOpName(spanName).
			WithOptions(tracing.SpanKind(ext.SpanKindRPCClientEnum), tracing.SpanTag("table", table)).
			DescendantOrNoSpan(parent)
	}
}

// makeAfterCallback returns a callback that tags the current span with "err" (on failure) or "rows"
// (rows affected), finishes it, and stores the rewound context back on db.Statement.
func (p gormPlugin) makeAfterCallback(_ string) gormCallbackFunc {
	return func(db *gorm.DB) {
		var tag tracing.SpanOption
		if db.Error == nil {
			tag = tracing.SpanTag("rows", db.RowsAffected)
		} else {
			tag = tracing.SpanTag("err", db.Error)
		}
		current := db.Statement.Context
		db.Statement.Context = tracing.WithTracer(p.tracer).WithOptions(tag).FinishAndRewind(current)
	}
}

// cbBeforeName returns the registered name of the "before" callback for the given operation.
func (p gormPlugin) cbBeforeName(name string) string {
	prefix := gormCallbackPrefix + "before_"
	return prefix + name
}

// cbAfterName returns the registered name of the "after" callback for the given operation.
func (p gormPlugin) cbAfterName(name string) string {
	prefix := gormCallbackPrefix + "after_"
	return prefix + name
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package data
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
"gorm.io/gorm"
)
// DbHealthIndicator reports database health by pinging the *sql.DB underlying a *gorm.DB.
// Note: we currently only support one database
type DbHealthIndicator struct {
	db *gorm.DB
}

// Name returns "database", the name this indicator reports under.
func (i *DbHealthIndicator) Name() string {
	return "database"
}

// Health pings the database and maps the outcome to a health status:
//   - StatusUnknown when the *sql.DB cannot be obtained from gorm,
//   - StatusDown when the ping fails,
//   - StatusUp otherwise.
//
// The ping uses c, so a deadline or cancellation on the health-check context bounds how long it can block.
func (i *DbHealthIndicator) Health(c context.Context, options health.Options) health.Health {
	sqldb, e := i.db.DB()
	if e != nil {
		return health.NewDetailedHealth(health.StatusUnknown, "database ping is not available", nil)
	}
	if c == nil {
		// defensive: PingContext requires a non-nil context
		c = context.Background()
	}
	if e := sqldb.PingContext(c); e != nil {
		return health.NewDetailedHealth(health.StatusDown, "database ping failed", nil)
	}
	return health.NewDetailedHealth(health.StatusUp, "database ping succeeded", nil)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package data
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/cisco-open/go-lanai/pkg/data/repo"
"github.com/cisco-open/go-lanai/pkg/data/tx"
"github.com/cisco-open/go-lanai/pkg/web"
"go.uber.org/fx"
"reflect"
)
//var logger = log.New("Data")
// Module is the "DB" bootstrap module. It provides the transaction max-retry executer option and the web
// data error translator, and depends on the data, tx and repo modules.
var Module = &bootstrap.Module{
Name: "DB",
Precedence: bootstrap.DatabasePrecedence,
Options: []fx.Option{
fx.Provide(
transactionMaxRetry(),
),
web.FxErrorTranslatorProviders(
webErrTranslatorProvider(data.NewWebDataErrorTranslator),
),
},
Modules: []*bootstrap.Module{
data.Module, tx.Module, repo.Module,
},
}
// Use registers Module with the bootstrap framework.
func Use() {
bootstrap.Register(Module)
}
/**************************
Provider
***************************/
// webErrTranslatorProvider adapts a no-argument constructor (passed as interface{}) into a provider of
// web.ErrorTranslator. The returned function calls the constructor via reflection and asserts its first
// result to web.ErrorTranslator; it panics if provider is not such a function.
func webErrTranslatorProvider(provider interface{}) func() web.ErrorTranslator {
	ctor := reflect.ValueOf(provider)
	return func() web.ErrorTranslator {
		results := ctor.Call(nil)
		translator := results[0].Interface()
		return translator.(web.ErrorTranslator)
	}
}
// transactionMaxRetry provides, into the tx.FxTransactionExecuterOption group, an executer option that sets
// max retries from DataProperties.Transaction.MaxRetry (second argument fixed at 0).
func transactionMaxRetry() fx.Annotated {
	target := func(properties data.DataProperties) tx.TransactionExecuterOption {
		return tx.MaxRetries(properties.Transaction.MaxRetry, 0)
	}
	return fx.Annotated{Group: tx.FxTransactionExecuterOption, Target: target}
}
/**************************
Initialize
***************************/
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package data
import (
"github.com/cisco-open/go-lanai/pkg/actuator/health"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/opentracing/opentracing-go"
"go.uber.org/fx"
"gorm.io/gorm"
"time"
)
// logger is the package logger named "Data".
var logger = log.New("Data")
// Module is the "DB" bootstrap module: it provides data properties, the *gorm.DB and the gorm error
// translator, and registers the database health indicator.
var Module = &bootstrap.Module{
Name: "DB",
Precedence: bootstrap.DatabasePrecedence,
Options: []fx.Option{
fx.Provide(
BindDataProperties,
provideGorm,
gormErrTranslatorProvider(),
),
fx.Invoke(registerHealth),
},
}
/**************************
Provider
***************************/
// gormInitDI holds the fx-injected dependencies of provideGorm.
type gormInitDI struct {
fx.In
Dialector gorm.Dialector
Properties DataProperties
// Configurers and Translators are both collected from the "gorm_config" value group.
// NOTE(review): presumably the same name as GormConfigurerGroup (tags cannot reference constants) — confirm.
Configurers []GormConfigurer `group:"gorm_config"`
Translators []ErrorTranslator `group:"gorm_config"`
// Tracer is optional; tracing callbacks are installed only when it is present.
Tracer opentracing.Tracer `optional:"true"`
}
// provideGorm builds the *gorm.DB via NewGorm using the injected dialector, logging properties, error
// translators, optional tracer and extra configurers. A positive SlowThreshold overrides the default
// slow-query threshold.
func provideGorm(di gormInitDI) *gorm.DB {
	configure := func(cfg *GormConfig) {
		cfg.Dialector = di.Dialector
		cfg.LogLevel = di.Properties.Logging.Level
		cfg.Configurers = append(cfg.Configurers, NewGormErrorHandlingConfigurer(di.Translators...))
		if tracer := di.Tracer; tracer != nil {
			cfg.Configurers = append(cfg.Configurers, NewGormTracingConfigurer(tracer))
		}
		cfg.Configurers = append(cfg.Configurers, di.Configurers...)
		if threshold := di.Properties.Logging.SlowThreshold; threshold > 0 {
			cfg.LogSlowQueryThreshold = time.Duration(threshold)
		}
	}
	return NewGorm(configure)
}
// gormErrTranslatorProvider provides the gorm ErrorTranslator into the GormConfigurerGroup value group.
func gormErrTranslatorProvider() fx.Annotated {
	provide := func() ErrorTranslator {
		return NewGormErrorTranslator()
	}
	return fx.Annotated{Group: GormConfigurerGroup, Target: provide}
}
/**************************
Initialize
***************************/
// regDI holds the optional dependencies of registerHealth.
type regDI struct {
	fx.In
	HealthRegistrar health.Registrar `optional:"true"`
	GormDB          *gorm.DB         `optional:"true"`
}

// registerHealth registers a DbHealthIndicator when both a health registrar and a *gorm.DB are available.
func registerHealth(di regDI) {
	if di.HealthRegistrar != nil && di.GormDB != nil {
		indicator := &DbHealthIndicator{db: di.GormDB}
		di.HealthRegistrar.MustRegister(indicator)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package cockroach
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/cisco-open/go-lanai/pkg/data/postgresql"
"go.uber.org/fx"
"gorm.io/gorm"
)
// GormDbCreator creates the configured database with "CREATE DATABASE IF NOT EXISTS".
type GormDbCreator struct {
	dbName string
}

// Order returns one less than postgresql.DBCreatorPostgresOrder.
func (g *GormDbCreator) Order() int {
	return postgresql.DBCreatorPostgresOrder - 1
}

// CreateDatabaseIfNotExist creates g.dbName (quoted by the statement's dialect) if missing.
// An ErrorInsufficientPrivilege failure is logged as a warning and treated as success.
func (g *GormDbCreator) CreateDatabaseIfNotExist(ctx context.Context, db *gorm.DB) error {
	stmt := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", db.Statement.Quote(g.dbName))
	err := db.WithContext(ctx).Exec(stmt).Error
	// errors.Is(nil, target) is false for a non-nil target, so a nil err falls through to "return nil".
	if errors.Is(err, data.ErrorInsufficientPrivilege) {
		logger.Warnf("Skipped creating database because %v", err)
		return nil
	}
	return err
}

// NewGormDbCreator returns a GormDbCreator for properties.DB.Database.
func NewGormDbCreator(properties data.DataProperties) data.DbCreator {
	creator := &GormDbCreator{}
	creator.dbName = properties.DB.Database
	return creator
}

// newAnnotatedGormDbCreator provides NewGormDbCreator into the data.GormConfigurerGroup value group.
func newAnnotatedGormDbCreator() fx.Annotated {
	return fx.Annotated{Group: data.GormConfigurerGroup, Target: NewGormDbCreator}
}
package cockroach
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/data/postgresql"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
)
var (
	logger = log.New("cockroach")

	// Module is the "cockroach" bootstrap module. It registers the cockroach DbCreator
	// and pulls in postgresql.Module as a dependency.
	Module = &bootstrap.Module{
		Name:       "cockroach",
		Precedence: bootstrap.DatabasePrecedence,
		Options: []fx.Option{
			fx.Provide(newAnnotatedGormDbCreator()),
		},
		Modules: []*bootstrap.Module{
			postgresql.Module,
		},
	}
)

// Use registers Module with the bootstrap framework.
func Use() {
	bootstrap.Register(Module)
}
package postgresql
import (
"context"
"github.com/cisco-open/go-lanai/pkg/data"
"go.uber.org/fx"
"gorm.io/gorm"
)
// DBCreatorPostgresOrder is the order value of the postgres NoOpDbCreator.
const DBCreatorPostgresOrder = 0

// NoOpDbCreator is a data.DbCreator that never creates anything.
type NoOpDbCreator struct{}

// Order returns DBCreatorPostgresOrder.
func (NoOpDbCreator) Order() int {
	return DBCreatorPostgresOrder
}

// CreateDatabaseIfNotExist always returns nil without touching the database.
// Postgres can't connect to a database that doesn't exist, so there is nothing to do here.
func (NoOpDbCreator) CreateDatabaseIfNotExist(_ context.Context, _ *gorm.DB) error {
	return nil
}

// NewGormDbCreator returns a *NoOpDbCreator as data.DbCreator.
func NewGormDbCreator() data.DbCreator {
	return new(NoOpDbCreator)
}

// newAnnotatedGormDbCreator contributes NewGormDbCreator to the data.GormConfigurerGroup fx group.
func newAnnotatedGormDbCreator() fx.Annotated {
	return fx.Annotated{
		Target: NewGormDbCreator,
		Group:  data.GormConfigurerGroup,
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package postgresql
import (
"context"
"errors"
"fmt"
"regexp"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/jackc/pgx/v5/pgconn"
"github.com/lib/pq"
)
// dataIntegrityRegexp matches "(<col>) = (<value>)" fragments in a postgres error detail,
// e.g. `Key (email)=(a@b.c) already exists.`, capturing the named groups "col" and "value".
var dataIntegrityRegexp = regexp.MustCompile(`\((?P<col>[^()]+)\) *= *\((?P<value>[^()]*)\)`)
// PostgresErrorTranslator implements data.ErrorTranslator.
// It translates *pq.Error and *pgconn.PgError into data.DataError.
// Note: cockroach uses gorm.io/driver/postgres, which internally uses github.com/jackc/pgx
// Ref:
// - Postgres Error: https://www.postgresql.org/docs/11/protocol-error-fields.html
// - Postgres Error Code: https://www.postgresql.org/docs/11/errcodes-appendix.html
type PostgresErrorTranslator struct{}

// NewPqErrorTranslator returns a PostgresErrorTranslator wrapped in data.DefaultGormErrorTranslator.
func NewPqErrorTranslator() data.ErrorTranslator {
	var inner PostgresErrorTranslator
	return data.DefaultGormErrorTranslator{ErrorTranslator: inner}
}

// Order returns 0.
func (t PostgresErrorTranslator) Order() int {
	return 0
}

// Translate converts a driver error (*pgconn.PgError or *pq.Error) into a data.DataError.
// The error code is mapped by translateErrorCode. Duplicate-key errors additionally get a
// message derived from the error detail. Any other error is returned unchanged.
func (t PostgresErrorTranslator) Translate(_ context.Context, err error) error {
	var code string
	//nolint:errorlint // we don't consider wrapped error here
	switch e := err.(type) {
	case *pgconn.PgError:
		code = e.Code
	case *pq.Error:
		code = string(e.Code)
	default:
		return err
	}

	de := data.NewDataError(t.translateErrorCode(code), err)
	if errors.Is(de, data.ErrorDuplicateKey) {
		return t.translateDuplicateKeyErrorMessage(de)
	}
	return de
}
// translateDuplicateKeyErrorMessage rewrites the message of a duplicate-key DataError.
// It uses the column and value parsed from the driver error's Detail, keeping the original
// cause. If the cause is not a *pgconn.PgError or *pq.Error, e is returned as-is.
func (t PostgresErrorTranslator) translateDuplicateKeyErrorMessage(e data.DataError) data.DataError {
	var detail string
	//nolint:errorlint // we don't consider wrapped error here
	switch cause := e.Cause().(type) {
	case *pgconn.PgError:
		detail = cause.Detail
	case *pq.Error:
		detail = cause.Detail
	default:
		return e
	}

	msg := "duplicate keys"
	if m := dataIntegrityRegexp.FindStringSubmatch(detail); m != nil {
		// the column part always precedes the value part in the final message
		if col := m[dataIntegrityRegexp.SubexpIndex("col")]; col != "" {
			msg = fmt.Sprintf("%s; %s", msg, fmt.Sprintf("duplicate key in column: %s", col))
		}
		if val := m[dataIntegrityRegexp.SubexpIndex("value")]; val != "" {
			msg = fmt.Sprintf("%s; %s", msg, fmt.Sprintf("duplicate keys: %s", val))
		}
	}
	return e.WithCause(e.Cause(), msg)
}
// translateErrorCode maps a postgres SQLSTATE error code to a data.DataError code.
// Mapping is mostly by error class (the first two characters of a five-character code),
// with a few specific codes singled out. Unknown or malformed codes map to
// data.ErrorTypeCodeUncategorizedServerSide.
// ref https://www.postgresql.org/docs/11/errcodes-appendix.html
func (t PostgresErrorTranslator) translateErrorCode(code string) int64 {
	// TODO more detailed error translation
	class := ""
	if len(code) == 5 {
		class = code[:2]
	}

	switch class {
	case "22", "26", "42": // data.ErrorSubTypeCodeQuery
		if code == "42501" {
			return data.ErrorCodeInsufficientPrivilege
		}
		return data.ErrorSubTypeCodeQuery
	case "24": // data.ErrorSubTypeCodeDataRetrieval
		return data.ErrorCodeIncorrectRecordCount
	case "21", "23", "27": // data.ErrorSubTypeCodeDataIntegrity
		if code == "23505" {
			return data.ErrorCodeDuplicateKey
		}
		return data.ErrorCodeConstraintViolation
	case "25", "2D", "2d", "3B", "3b", "40": // data.ErrorSubTypeCodeTransaction
		return data.ErrorCodeInvalidTransaction
	case "28": // data.ErrorSubTypeCodeSecurity
		return data.ErrorCodeAuthenticationFailed
	case "55": // data.ErrorSubTypeCodeConcurrency
		return data.ErrorSubTypeCodeConcurrency
	case "53": // data.ErrorTypeCodeTransient
		return data.ErrorTypeCodeTransient
	default:
		return data.ErrorTypeCodeUncategorizedServerSide
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package postgresql
import (
	"fmt"
	"sort"
	"strings"

	"github.com/cisco-open/go-lanai/pkg/bootstrap"
	"github.com/cisco-open/go-lanai/pkg/certs"
	"github.com/cisco-open/go-lanai/pkg/data"
	"go.uber.org/fx"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)
// Keywords of the keyword/value connection string built by NewGormDialetor.
const (
	dsKeyHost        = "host"
	dsKeyPort        = "port"
	dsKeyDB          = "dbname"
	dsKeySslMode     = "sslmode"
	dsKeyUsername    = "user"
	dsKeyPassword    = "password"
	dsKeySslRootCert = "sslrootcert"
	dsKeySslCert     = "sslcert"
	dsKeySslKey      = "sslkey"
	// keyword must be exactly "sslpassword" (no trailing whitespace)
	dsKeySslKeyPass = "sslpassword"
)
// initDI lists the dependencies of NewGormDialetor. CertsManager is optional; without it,
// no TLS certificate options are added to the DSN.
type initDI struct {
	fx.In
	AppContext   *bootstrap.ApplicationContext
	Properties   data.DataProperties
	CertsManager certs.Manager `optional:"true"`
}

// NewGormDialetor builds the postgres-compatible gorm.Dialector from data.DataProperties.
//
// The DSN always contains host, port, dbname and sslmode. It contains user and password
// when a username is configured. When TLS is enabled and a certs.Manager is available, it
// also gets the certificate file options. A failure to provision certificates is logged
// and the DSN is built without them.
func NewGormDialetor(di initDI) gorm.Dialector {
	//"host=localhost user=root password=root dbname=idm port=26257 sslmode=disable"
	options := map[string]interface{}{
		dsKeyHost:    di.Properties.DB.Host,
		dsKeyPort:    di.Properties.DB.Port,
		dsKeyDB:      di.Properties.DB.Database,
		dsKeySslMode: di.Properties.DB.SslMode,
	}

	// Setup TLS properties
	if di.Properties.DB.Tls.Enable && di.CertsManager != nil {
		if e := setDsnTlsOptions(di, options); e != nil {
			logger.Errorf("Failed to provision TLS certificates: %v", e)
		}
	}

	if di.Properties.DB.Username != "" {
		options[dsKeyUsername] = di.Properties.DB.Username
		options[dsKeyPassword] = di.Properties.DB.Password
	}

	config := postgres.Config{
		//DriverName: "postgres",
		DSN: toDSN(options),
	}
	return NewGormDialectorWithConfig(config)
}

// setDsnTlsOptions resolves the configured certificate source and writes the resulting
// file paths (and key passphrase, if any) into options. Errors from both resolving the
// source and materializing its files are returned to the caller; none are dropped.
func setDsnTlsOptions(di initDI, options map[string]interface{}) error {
	source, e := di.CertsManager.Source(di.AppContext, certs.WithSourceProperties(&di.Properties.DB.Tls.Certs))
	if e != nil {
		return e
	}
	certFiles, e := source.Files(di.AppContext)
	if e != nil {
		return e
	}
	// NOTE(review): multiple root CAs are joined with spaces; confirm the driver accepts more than one path here
	options[dsKeySslRootCert] = strings.Join(certFiles.RootCAPaths, " ")
	options[dsKeySslCert] = certFiles.CertificatePath
	options[dsKeySslKey] = certFiles.PrivateKeyPath
	if len(certFiles.PrivateKeyPassphrase) != 0 {
		options[dsKeySslKeyPass] = certFiles.PrivateKeyPassphrase
	}
	return nil
}
// toDSN renders options as a keyword/value connection string, for example
// "dbname='idm' host='localhost' port='26257'".
//
// Keys are emitted in sorted order, so the result is deterministic. Every value is
// formatted with %v, then single-quoted with backslashes and single quotes backslash-escaped.
// Without quoting, an empty value (such as the default empty password) or a value
// containing whitespace would be mis-parsed, swallowing the following keyword.
func toDSN(options map[string]interface{}) string {
	keys := make([]string, 0, len(options))
	for k := range options {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	// Replacer substitutes in a single pass, so inserted backslashes are not re-escaped
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
	opts := make([]string, 0, len(keys))
	for _, k := range keys {
		val := escaper.Replace(fmt.Sprintf("%v", options[k]))
		opts = append(opts, fmt.Sprintf("%s='%s'", k, val))
	}
	return strings.Join(opts, " ")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package postgresql
import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
/*************************
Custom GormDialector
*************************/
// GormDialector embeds postgres.Dialector and overrides Migrator to return a GormMigrator.
type GormDialector struct {
	postgres.Dialector
}

// NewGormDialectorWithConfig creates a GormDialector from a postgres.Config.
func NewGormDialectorWithConfig(config postgres.Config) *GormDialector {
	base := postgres.New(config).(*postgres.Dialector)
	return &GormDialector{Dialector: *base}
}

// Migrator returns a GormMigrator bound to db, using this dialector.
func (d GormDialector) Migrator(db *gorm.DB) gorm.Migrator {
	return NewGormMigrator(db, d)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package postgresql
import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/migrator"
)
/*************************
Custom Migrator
*************************/
// GormMigrator embeds postgres.Migrator.
// Inverted index support:
// for now, use PostgreSQL-compatible syntax: https://www.cockroachlabs.com/docs/v20.2/inverted-indexes#creation
type GormMigrator struct {
	postgres.Migrator
}

// NewGormMigrator creates a GormMigrator for db and dialector, with CreateIndexAfterCreateTable enabled.
func NewGormMigrator(db *gorm.DB, dialector gorm.Dialector) *GormMigrator {
	cfg := migrator.Config{
		DB:                          db,
		Dialector:                   dialector,
		CreateIndexAfterCreateTable: true,
	}
	ret := &GormMigrator{}
	ret.Migrator = postgres.Migrator{
		Migrator: migrator.Migrator{Config: cfg},
	}
	return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package postgresql
import "go.uber.org/fx"
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/cisco-open/go-lanai/pkg/log"
)
var (
	logger = log.New("postgresql")

	// Module is the "postgres-compatible" bootstrap module. It provides the gorm dialector,
	// the postgres error translator and the postgres DbCreator.
	Module = &bootstrap.Module{
		Name:       "postgres-compatible",
		Precedence: bootstrap.DatabasePrecedence,
		Options: []fx.Option{
			fx.Provide(
				NewGormDialetor,
				pqErrorTranslatorProvider(),
				newAnnotatedGormDbCreator(),
			),
		},
	}
)

// Use registers Module with the bootstrap framework.
func Use() {
	bootstrap.Register(Module)
}

/**************************
	Provider
***************************/

// pqErrorTranslatorProvider contributes NewPqErrorTranslator to the data.GormConfigurerGroup fx group.
func pqErrorTranslatorProvider() fx.Annotated {
	return fx.Annotated{
		Target: NewPqErrorTranslator,
		Group:  data.GormConfigurerGroup,
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package data
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"time"
)
// PropertiesPrefix is the configuration prefix that DataProperties is bound from.
const PropertiesPrefix = "data"

type (
	// DataProperties holds all configuration found under PropertiesPrefix.
	DataProperties struct {
		Logging     LoggingProperties     `json:"logging"`
		Transaction TransactionProperties `json:"transaction"`
		DB          DatabaseProperties    `json:"db"`
	}

	// TransactionProperties holds transaction-related settings.
	TransactionProperties struct {
		MaxRetry int `json:"max-retry"`
	}

	// LoggingProperties holds the logging level and slow threshold.
	LoggingProperties struct {
		Level         log.LoggingLevel `json:"level"`
		SlowThreshold utils.Duration   `json:"slow-threshold"`
	}

	// DatabaseProperties describes how to connect to the database server.
	DatabaseProperties struct {
		Host     string `json:"host"`
		Port     int    `json:"port"`
		Database string `json:"database"`
		Username string `json:"username"`
		Password string `json:"password"`
		SslMode  string `json:"sslmode"`
		Tls      TLS    `json:"tls"`
	}

	// TLS holds TLS settings for the database connection.
	// Note that Enable is bound from the "enabled" key.
	TLS struct {
		Enable bool                   `json:"enabled"`
		Certs  certs.SourceProperties `json:"certs"`
	}
)
// NewDataProperties creates a DataProperties populated with default values.
// Defaults are: log level warn, a 15s slow threshold, MaxRetry 5, and a connection to
// localhost:26257 as "root" with an empty password and sslmode "disable".
func NewDataProperties() *DataProperties {
	props := new(DataProperties)
	props.Logging.Level = log.LevelWarn
	props.Logging.SlowThreshold = utils.Duration(15 * time.Second)
	props.Transaction.MaxRetry = 5
	props.DB.Host = "localhost"
	props.DB.Port = 26257
	props.DB.Username = "root"
	props.DB.Password = ""
	props.DB.SslMode = "disable"
	return props
}

// BindDataProperties binds configuration under PropertiesPrefix onto the defaults from
// NewDataProperties. It panics if binding fails.
func BindDataProperties(ctx *bootstrap.ApplicationContext) DataProperties {
	props := NewDataProperties()
	err := ctx.Config().Bind(props, PropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind DataProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package repo
import (
"context"
"database/sql"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/cisco-open/go-lanai/pkg/data/tx"
"gorm.io/gorm"
)
// TxWithGormFunc is the callback run by GormApi.Transaction with the transactional *gorm.DB.
type TxWithGormFunc func(ctx context.Context, tx *gorm.DB) error

// GormApi gives repositories access to a context-aware *gorm.DB and to transactions.
type GormApi interface {
	DB(ctx context.Context) *gorm.DB
	Transaction(ctx context.Context, txFunc TxWithGormFunc, opts ...*sql.TxOptions) error
	WithSession(config *gorm.Session) GormApi
}

// gormApi is the default GormApi; its txManager is always bound to its db.
type gormApi struct {
	db        *gorm.DB
	txManager tx.GormTxManager
}

// newGormApi creates a gormApi whose transaction manager is bound to db.
func newGormApi(db *gorm.DB, txManager tx.GormTxManager) GormApi {
	return gormApi{db: db, txManager: txManager.WithDB(db)}
}

// WithSession returns a GormApi whose db is derived from config, with the transaction
// manager re-bound to that new db.
func (g gormApi) WithSession(config *gorm.Session) GormApi {
	return newGormApi(g.db.Session(config), g.txManager)
}

// DB returns the transactional *gorm.DB stored in ctx, if any. Otherwise it returns
// the base db bound to ctx.
func (g gormApi) DB(ctx context.Context) *gorm.DB {
	t := tx.GormTxWithContext(ctx)
	if t == nil {
		return g.db.WithContext(ctx)
	}
	return t
}

// Transaction runs txFunc through the transaction manager, passing the gorm Tx found in
// the transactional context. If no gorm Tx is found, it returns an
// ErrorCodeInvalidTransaction data error.
func (g gormApi) Transaction(ctx context.Context, txFunc TxWithGormFunc, opts ...*sql.TxOptions) error {
	wrapped := func(c context.Context) error {
		gormTx := tx.GormTxWithContext(c)
		if gormTx == nil {
			return data.NewDataError(data.ErrorCodeInvalidTransaction, "gorm Tx is not found in context")
		}
		return txFunc(c, gormTx)
	}
	return g.txManager.Transaction(ctx, wrapped, opts...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package repo
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/google/uuid"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"reflect"
)
// typeKey classifies the Go shape of a value passed to CRUD operations.
type typeKey int

const (
	typeModelPtr         typeKey = iota // e.g. *Model
	typeModel                           // e.g. Model
	typeModelSlicePtr                   // e.g. *[]Model
	typeModelPtrSlicePtr                // e.g. *[]*Model{}
	typeModelSlice                      // e.g. []Model
	typeModelPtrSlice                   // e.g. []*Model
	typeGenericMap                      // map[string]interface{}
)

// Message templates: %T of the offending value, operation name, accepted shapes.
const (
	errTmplInvalidCrudValue = `%T is not a valid value for %s, requires %s`
	errTmplInvalidCrudModel = "%T is not a valid model for %s, requires %s"
)

// Sets of typeKey accepted by the different CRUD operations.
var (
	singleModelRead  = utils.NewSet(typeModelPtr)
	multiModelRead   = utils.NewSet(typeModelPtrSlicePtr, typeModelSlicePtr)
	singleModelWrite = utils.NewSet(typeModelPtr, typeModel)
	//multiModelWrite = utils.NewSet(typeModelPtrSlice, typeModelSlice, typeModelPtrSlicePtr, typeModelSlicePtr)
	genericModelWrite = utils.NewSet(
		typeModelPtr,
		typeModelPtrSlice,
		typeGenericMap,
		typeModelPtrSlicePtr,
		typeModelSlice,
		typeModelSlicePtr,
		typeModel,
	)
)
// GormCrud implements CrudRepository and can be embedded into any repository that uses gorm as ORM.
type GormCrud struct {
	GormApi
	GormMetadata
}

// newGormCrud builds a GormCrud for model, resolving its metadata through api.
func newGormCrud(api GormApi, model interface{}) (*GormCrud, error) {
	// Note: the raw db is used here to leverage gorm's internal schema cache
	meta, e := newGormMetadata(api.DB(context.Background()), model)
	if e != nil {
		return nil, e
	}
	return &GormCrud{GormApi: api, GormMetadata: meta}, nil
}
// FindById loads the record identified by id into dest, which must be a *Struct.
// A string or *string id that parses as a UUID is converted to uuid.UUID first; any
// other id is handed to gorm unchanged. A nil *string is passed through rather than
// dereferenced.
func (g GormCrud) FindById(ctx context.Context, dest interface{}, id interface{}, options ...Option) error {
	if !g.isSupportedValue(dest, singleModelRead) {
		return ErrorInvalidCrudParam.
			WithMessage(errTmplInvalidCrudValue, dest, "FindById", "*Struct")
	}

	// TODO verify this using index key
	switch v := id.(type) {
	case string:
		if uid, e := uuid.Parse(v); e == nil {
			id = uid
		}
	case *string:
		// guard against a nil pointer, which previously caused a panic on dereference
		if v != nil {
			if uid, e := uuid.Parse(*v); e == nil {
				id = uid
			}
		}
	}

	return execute(ctx, g.GormApi.DB(ctx), nil, options, modelFunc(g.model), func(db *gorm.DB) *gorm.DB {
		return db.Take(dest, id)
	})
}
// FindAll loads all records of the model into dest, which must be *[]Struct or *[]*Struct.
func (g GormCrud) FindAll(ctx context.Context, dest interface{}, options ...Option) error {
	if ok := g.isSupportedValue(dest, multiModelRead); !ok {
		return ErrorInvalidCrudParam.WithMessage(errTmplInvalidCrudValue, dest, "FindAll", "*[]Struct or *[]*Struct")
	}
	find := func(db *gorm.DB) *gorm.DB { return db.Find(dest) }
	return execute(ctx, g.GormApi.DB(ctx), nil, options, modelFunc(g.model), find)
}

// FindOneBy loads the first record matching condition into dest, which must be a *Struct.
func (g GormCrud) FindOneBy(ctx context.Context, dest interface{}, condition Condition, options ...Option) error {
	if ok := g.isSupportedValue(dest, singleModelRead); !ok {
		return ErrorInvalidCrudParam.WithMessage(errTmplInvalidCrudValue, dest, "FindOneBy", "*Struct")
	}
	take := func(db *gorm.DB) *gorm.DB { return db.Take(dest) }
	return execute(ctx, g.GormApi.DB(ctx), condition, options, modelFunc(g.model), take)
}

// FindAllBy loads all records matching condition into dest, which must be *[]Struct or *[]*Struct.
func (g GormCrud) FindAllBy(ctx context.Context, dest interface{}, condition Condition, options ...Option) error {
	if ok := g.isSupportedValue(dest, multiModelRead); !ok {
		return ErrorInvalidCrudParam.WithMessage(errTmplInvalidCrudValue, dest, "FindAllBy", "*[]Struct or *[]*Struct")
	}
	find := func(db *gorm.DB) *gorm.DB { return db.Find(dest) }
	return execute(ctx, g.GormApi.DB(ctx), condition, options, modelFunc(g.model), find)
}
// CountAll counts all records of the model. It returns -1 together with any error.
func (g GormCrud) CountAll(ctx context.Context, options ...Option) (int, error) {
	// a nil condition applies no filtering, so this is the unconditional count
	return g.CountBy(ctx, nil, options...)
}

// CountBy counts records of the model matching condition. It returns -1 together with any error.
func (g GormCrud) CountBy(ctx context.Context, condition Condition, options ...Option) (int, error) {
	var count int64
	countFn := func(db *gorm.DB) *gorm.DB { return db.Count(&count) }
	if e := execute(ctx, g.GormApi.DB(ctx), condition, options, modelFunc(g.model), countFn); e != nil {
		return -1, e
	}
	return int(count), nil
}
// Save upserts v through gorm's Save. It does not set the template model on the statement.
func (g GormCrud) Save(ctx context.Context, v interface{}, options ...Option) error {
	if ok := g.isSupportedValue(v, genericModelWrite); !ok {
		return ErrorInvalidCrudParam.WithMessage(errTmplInvalidCrudValue, v, "Save", "*Struct or []*Struct or []Struct")
	}
	save := func(db *gorm.DB) *gorm.DB { return db.Save(v) }
	return execute(ctx, g.GormApi.DB(ctx), nil, options, nil, save)
}

// Create inserts v through gorm's Create.
func (g GormCrud) Create(ctx context.Context, v interface{}, options ...Option) error {
	if ok := g.isSupportedValue(v, genericModelWrite); !ok {
		return ErrorInvalidCrudParam.WithMessage(errTmplInvalidCrudValue, v, "Create", "*Struct, []*Struct or []Struct")
	}
	create := func(db *gorm.DB) *gorm.DB { return db.Create(v) }
	return execute(ctx, g.GormApi.DB(ctx), nil, options, modelFunc(g.model), create)
}
// Update applies the changes in v, via gorm's Updates, to the record(s) identified by
// model. model must be *Struct or Struct; v is forwarded to Updates unchanged.
func (g GormCrud) Update(ctx context.Context, model interface{}, v interface{}, options ...Option) error {
	if !g.isSupportedValue(model, singleModelWrite) {
		// report the rejected model's type, since that is what failed validation
		return ErrorInvalidCrudParam.
			WithMessage(errTmplInvalidCrudModel, model, "Update", "*Struct or Struct")
	}
	return execute(ctx, g.GormApi.DB(ctx), nil, options, modelFunc(model), func(db *gorm.DB) *gorm.DB {
		// note we use the actual model instead of template g.model
		return db.Updates(v)
	})
}
// Delete removes v through gorm's Delete.
func (g GormCrud) Delete(ctx context.Context, v interface{}, options ...Option) error {
	if ok := g.isSupportedValue(v, genericModelWrite); !ok {
		return ErrorInvalidCrudParam.WithMessage(errTmplInvalidCrudValue, v, "Delete", "*Struct, []Struct or []*Struct")
	}
	del := func(db *gorm.DB) *gorm.DB { return db.Delete(v) }
	return execute(ctx, g.GormApi.DB(ctx), nil, options, modelFunc(g.model), del)
}

// DeleteBy removes records of the model matching condition.
func (g GormCrud) DeleteBy(ctx context.Context, condition Condition, options ...Option) error {
	del := func(db *gorm.DB) *gorm.DB {
		// a fresh zero value of the model type tells gorm which table to delete from
		zero := reflect.New(g.ModelType()).Interface()
		return db.Delete(zero)
	}
	return execute(ctx, g.GormApi.DB(ctx), condition, options, modelFunc(g.model), del)
}
func (g GormCrud) Truncate(ctx context.Context) error {
return execute(ctx, g.GormApi.DB(ctx), nil, nil, modelFunc(g.model), func(db *gorm.DB) *gorm.DB {
if e := db.Statement.Parse(g.model); e != nil {
_ = db.AddError(ErrorInvalidCrudModel.WithMessage("unable to parse table name for model %T", g.model))
return db
}
table := interface{}(db.Statement.TableExpr)
if db.Statement.TableExpr == nil {
table = db.Statement.Table
}
return db.Exec(fmt.Sprintf(`TRUNCATE TABLE %s CASCADE`, db.Statement.Quote(table)))
})
}
/*******************
Helpers
*******************/
// modelFunc returns a scope that sets m as the statement model via (*gorm.DB).Model.
func modelFunc(m interface{}) func(*gorm.DB) *gorm.DB {
	scope := func(db *gorm.DB) *gorm.DB { return db.Model(m) }
	return scope
}
func execute(_ context.Context, db *gorm.DB, condition Condition, opts []Option, preOptsFn, fn func(*gorm.DB) *gorm.DB) error {
// make a copy of option array
options := make([]Option, len(opts), len(opts)+1)
copy(options, opts)
if preOptsFn != nil {
options = append(options, priorityOption{order: order.Highest, wrapped: preOptsFn})
}
// prepare
var e error
if db, e = applyOptions(db, options); e != nil {
return e
}
if db, e = applyCondition(db, condition); e != nil {
return e
}
// execute
r := fn(db)
// post exec
switch r, e := applyPostExecOptions(r, options); {
case e != nil:
return e
case r.Error != nil:
return r.Error
default:
return nil
}
}
// optsToDBFuncs flattens opts into gorm scope functions after stable-sorting opts in place.
//
// Slices/arrays are expanded recursively. priorityOption and delayedOption are unwrapped.
// postExecOptions are skipped. gormOptions and plain func(*gorm.DB)*gorm.DB are used
// directly. Any other type yields ErrorUnsupportedOptions.
func optsToDBFuncs(opts []Option) ([]func(*gorm.DB) *gorm.DB, error) {
	order.SortStable(opts, order.UnorderedMiddleCompare)
	funcs := make([]func(*gorm.DB) *gorm.DB, 0, len(opts))
	for _, opt := range opts {
		if rv := reflect.ValueOf(opt); rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array {
			expanded := make([]Option, rv.Len())
			for i := range expanded {
				expanded[i] = rv.Index(i).Interface()
			}
			sub, err := optsToDBFuncs(expanded)
			if err != nil {
				return nil, err
			}
			funcs = append(funcs, sub...)
			continue
		}

		var wrapped Option
		switch o := opt.(type) {
		case postExecOptions:
			// postExecOptions is not counted as condition, ignore
			continue
		case priorityOption:
			wrapped = o.wrapped
		case delayedOption:
			wrapped = o.wrapped
		case gormOptions:
			funcs = append(funcs, o)
			continue
		case func(*gorm.DB) *gorm.DB:
			funcs = append(funcs, o)
			continue
		default:
			return nil, ErrorUnsupportedOptions.WithMessage("unsupported Option %T", opt)
		}

		// only wrapper options reach here
		sub, err := optsToDBFuncs([]Option{wrapped})
		if err != nil {
			return nil, err
		}
		funcs = append(funcs, sub...)
	}
	return funcs, nil
}
// applyOptions applies opts to db in their sorted order and returns the result with its Error.
// An empty opts returns db unchanged with a nil error.
func applyOptions(db *gorm.DB, opts []Option) (*gorm.DB, error) {
	if len(opts) == 0 {
		return db, nil
	}
	funcs, err := optsToDBFuncs(opts)
	if err != nil {
		return nil, err
	}
	// Funcs are applied directly instead of via db.Scopes(...), so they are not
	// mixed up with other scopes registered elsewhere
	result := db
	for i := range funcs {
		result = funcs[i](result)
	}
	return result, result.Error
}
// conditionToDBFuncs converts condition into gorm scope functions.
//
// Slices/arrays are expanded recursively. postExecOptions become a no-op. gormOptions and
// func(*gorm.DB)*gorm.DB are used as-is. clause.Where is added through Clauses. Anything
// else is passed to (*gorm.DB).Where.
func conditionToDBFuncs(condition Condition) ([]func(*gorm.DB) *gorm.DB, error) {
	cv := reflect.ValueOf(condition)
	if kind := cv.Kind(); kind != reflect.Slice && kind != reflect.Array {
		var scope func(*gorm.DB) *gorm.DB
		switch where := condition.(type) {
		case postExecOptions:
			// postExecOptions is not counted as condition, scope is a noop
			scope = func(db *gorm.DB) *gorm.DB { return db }
		case gormOptions:
			scope = where
		case func(*gorm.DB) *gorm.DB:
			scope = where
		case clause.Where:
			scope = func(db *gorm.DB) *gorm.DB { return db.Clauses(where) }
		default:
			scope = func(db *gorm.DB) *gorm.DB { return db.Where(condition) }
		}
		return []func(*gorm.DB) *gorm.DB{scope}, nil
	}

	n := cv.Len()
	scopes := make([]func(*gorm.DB) *gorm.DB, 0, n)
	for i := 0; i < n; i++ {
		sub, err := conditionToDBFuncs(cv.Index(i).Interface())
		if err != nil {
			return nil, err
		}
		scopes = append(scopes, sub...)
	}
	return scopes, nil
}
// applyCondition applies condition to db and returns the result with its Error.
// A nil condition returns db unchanged with a nil error.
func applyCondition(db *gorm.DB, condition Condition) (*gorm.DB, error) {
	if condition == nil {
		return db, nil
	}
	funcs, err := conditionToDBFuncs(condition)
	if err != nil {
		return nil, err
	}
	// Funcs are applied directly instead of via db.Scopes(...), so they are not
	// mixed up with other scopes registered elsewhere
	result := db
	for i := range funcs {
		result = funcs[i](result)
	}
	return result, result.Error
}
func postExecOptsToDBFuncs(opts []Option) ([]func(*gorm.DB) *gorm.DB, error) {
scopes := make([]func(*gorm.DB) *gorm.DB, 0, len(opts))
for _, v := range opts {
switch rv := reflect.ValueOf(v); rv.Kind() {
case reflect.Slice, reflect.Array:
size := rv.Len()
slice := make([]Option, size)
for i := 0; i < size; i++ {
slice[i] = rv.Index(i).Interface()
}
sub, e := postExecOptsToDBFuncs(slice)
if e != nil {
return nil, e
}
scopes = append(scopes, sub...)
default:
switch opt := v.(type) {
case postExecOptions:
scopes = append(scopes, opt)
}
}
}
return scopes, nil
}
// applyPostExecOptions applies the postExecOptions found in opts to the executed db and
// returns the result with its Error. An empty opts returns db unchanged with a nil error.
func applyPostExecOptions(db *gorm.DB, opts []Option) (*gorm.DB, error) {
	if len(opts) == 0 {
		return db, nil
	}
	funcs, err := postExecOptsToDBFuncs(opts)
	if err != nil {
		return nil, err
	}
	result := db
	for i := range funcs {
		result = funcs[i](result)
	}
	return result, result.Error
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package repo
import (
"github.com/cisco-open/go-lanai/pkg/data/tx"
"gorm.io/gorm"
)
// GormFactory creates gorm-backed repositories sharing one db and transaction manager.
type GormFactory struct {
	db        *gorm.DB
	txManager tx.GormTxManager
	api       GormApi
}

// newGormFactory creates a GormFactory with a default GormApi built from db and txManager.
func newGormFactory(db *gorm.DB, txManager tx.GormTxManager) Factory {
	f := &GormFactory{db: db, txManager: txManager}
	f.api = newGormApi(db, txManager)
	return f
}

// NewCRUD creates a CrudRepository for model. options are interpreted as in NewGormApi.
// It panics if the model's metadata cannot be resolved.
func (f GormFactory) NewCRUD(model interface{}, options ...interface{}) CrudRepository {
	crud, e := newGormCrud(f.NewGormApi(options...), model)
	if e != nil {
		panic(e)
	}
	return crud
}

// NewGormApi returns the factory's GormApi with every gorm.Session or *gorm.Session in
// options applied in order. Options of other types are ignored.
func (f GormFactory) NewGormApi(options ...interface{}) GormApi {
	api := f.api
	for _, option := range options {
		if session, ok := option.(gorm.Session); ok {
			api = api.WithSession(&session)
		} else if session, ok := option.(*gorm.Session); ok {
			api = api.WithSession(session)
		}
	}
	return api
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package repo
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
const maxUInt32 = int(^uint32(0))

type (
	// gormOptions is a scope function usable both as Option and as Condition.
	gormOptions func(*gorm.DB) *gorm.DB

	// priorityOption wraps an option so that it is guaranteed to run before regular options.
	// It implements order.PriorityOrdered.
	priorityOption struct {
		order   int
		wrapped interface{}
	}

	// delayedOption wraps an option so that it is guaranteed to run after regular options.
	// It implements order.Ordered.
	delayedOption struct {
		order   int
		wrapped interface{}
	}

	// postExecOptions is applied after the SQL is executed. It is mostly useful to
	// assert or update the result, or to update the error.
	postExecOptions func(*gorm.DB) *gorm.DB
)

// PriorityOrder returns the wrapper's order value.
func (o priorityOption) PriorityOrder() int {
	return o.order
}

// Order returns the wrapper's order value.
func (o delayedOption) Order() int {
	return o.order
}
/********************
Util Functions
********************/
// MustApplyOptions stable-sorts opts in place and applies them to db.
// It is intended for custom repository implementations and panics if any Option has an
// unsupported type.
func MustApplyOptions(db *gorm.DB, opts ...Option) *gorm.DB {
	order.SortStable(opts, order.UnorderedMiddleCompare)
	scope := AsGormScope(opts)
	return scope(db)
}

// MustApplyConditions applies conds to db.
// It is intended for custom repository implementations and panics if any Condition has an
// unsupported type.
func MustApplyConditions(db *gorm.DB, conds ...Condition) *gorm.DB {
	scope := AsGormScope(conds)
	return scope(db)
}
// AsGormScope converts the following types to a single func(*gorm.DB)*gorm.DB:
//   - Option or slice of Option
//   - Condition or slice of Condition
//   - func(*gorm.DB)*gorm.DB (returned as-is)
//   - slice of func(*gorm.DB)*gorm.DB
//
// This function is intended for custom repository implementations. The result can be used
// as "db.Scopes(result...)". It panics on postExecOptions and on anything conversion rejects.
func AsGormScope(i interface{}) func(*gorm.DB) *gorm.DB {
	var (
		funcs []func(*gorm.DB) *gorm.DB
		err   error
	)
	switch v := i.(type) {
	case func(*gorm.DB) *gorm.DB:
		return v
	case []func(*gorm.DB) *gorm.DB:
		funcs = v
	case []Option:
		funcs, err = optsToDBFuncs(v)
	case []Condition, clause.Where:
		funcs, err = conditionToDBFuncs(Condition(i))
	case gormOptions, priorityOption, delayedOption:
		funcs, err = optsToDBFuncs([]Option{i})
	case postExecOptions:
		err = ErrorUnsupportedOptions.WithMessage("unsupported Option %T", v)
	default:
		funcs, err = conditionToDBFuncs(Condition(i))
	}

	if err != nil {
		panic(err)
	}
	switch len(funcs) {
	case 0:
		return func(db *gorm.DB) *gorm.DB { return db }
	case 1:
		return funcs[0]
	}
	return func(db *gorm.DB) *gorm.DB {
		for _, fn := range funcs {
			db = fn(db)
		}
		return db
	}
}
/**************************
Options & Conditions
**************************/
// Or is a Condition that directly bridges its parameters to (*gorm.DB).Or()
func Or(query interface{}, args ...interface{}) Condition {
	orScope := func(db *gorm.DB) *gorm.DB {
		return db.Or(query, args...)
	}
	return gormOptions(orScope)
}

// Where is a Condition that directly bridges its parameters to (*gorm.DB).Where()
func Where(query interface{}, args ...interface{}) Condition {
	whereScope := func(db *gorm.DB) *gorm.DB {
		return db.Where(query, args...)
	}
	return gormOptions(whereScope)
}

// Joins is an Option for Find* operations, typically used to populate a "ToOne" relationship using a JOIN clause
// e.g. CrudRepository.FindById(ctx, &user, Joins("Status"))
//
// When used on "ToMany", a JOIN query is usually used instead of a field name
// e.g. CrudRepository.FindById(ctx, &user, Joins("JOIN address ON address.user_id = users.id AND address.country = ?", "Canada"))
func Joins(query string, args ...interface{}) Option {
	joinScope := func(db *gorm.DB) *gorm.DB {
		return db.Joins(query, args...)
	}
	return gormOptions(joinScope)
}

// Preload is an Option for Find* operations, typically used to populate relationship fields using separate queries
// e.g.
//   CrudRepository.FindAll(ctx, &user, Preload("Roles.Permissions"))
//   CrudRepository.FindAll(ctx, &user, Preload("Roles", "role_name NOT IN (?)", "excluded"))
func Preload(query string, args ...interface{}) Option {
	preloadScope := func(db *gorm.DB) *gorm.DB {
		return db.Preload(query, args...)
	}
	return gormOptions(preloadScope)
}

// Omit is an Option specifying fields that you want to ignore when creating, updating and querying.
// When supported by gorm.io, this Option is a direct bridge to (*gorm.DB).Omit().
// Please see https://gorm.io/docs/ for detailed usage
func Omit(fields ...string) Option {
	omitScope := func(db *gorm.DB) *gorm.DB {
		return db.Omit(fields...)
	}
	return gormOptions(omitScope)
}

// Select is an Option specifying the fields you want when querying, creating or updating.
// This Option has a different meaning depending on the operation (query vs create vs update vs save vs delete).
// When supported by gorm.io, this Option is a direct bridge to (*gorm.DB).Select().
// Please see https://gorm.io/docs/ for detailed usage
func Select(query interface{}, args ...interface{}) Option {
	selectScope := func(db *gorm.DB) *gorm.DB {
		return db.Select(query, args...)
	}
	return gormOptions(selectScope)
}
// Page is an Option specifying pagination when retrieving records from the database
//   page: page number starting at 0
//   size: page size (# of records per page)
// e.g.
//   CrudRepository.FindAll(ctx, &user, Page(2, 10))
//   CrudRepository.FindAllBy(ctx, &user, Where(...), Page(2, 10))
//
// Invalid input (negative page, non-positive size, or a page whose last record would reach maxUInt32) adds
// ErrorInvalidPagination to the *gorm.DB instead of applying OFFSET/LIMIT.
// An ORDER BY on the primary key is added via AddClauseIfNotExists so pages have a fixed order.
// The returned option is a delayedOption with order.Lowest so that it runs after any Sort or SortBy.
func Page(page, size int) Option {
	opt := gormOptions(func(db *gorm.DB) *gorm.DB {
		// Validate before multiplying: page*size can overflow int for very large (possibly client-supplied)
		// values and wrap around into a small offset that would otherwise pass validation.
		// For page >= 0 and size > 0:  page*size+size >= maxUInt32  <=>  page >= (maxUInt32-1)/size
		if page < 0 || size <= 0 || page >= (maxUInt32-1)/size {
			_ = db.AddError(ErrorInvalidPagination)
			return db
		}
		offset := page * size
		db = db.Offset(offset).Limit(size)
		// add default sorting to ensure fixed order
		sort := clause.OrderByColumn{Column: clause.Column{Name: clause.PrimaryKey}}
		db.Statement.AddClauseIfNotExists(clause.OrderBy{
			Columns: []clause.OrderByColumn{sort},
		})
		return db
	})
	// we want to run this option AFTER any Sort or SortBy
	return delayedOption{
		order:   order.Lowest,
		wrapped: opt,
	}
}
// Sort is an Option specifying the order of records retrieved from the database, by column.
// This Option is typically used together with the Page option.
// When supported by gorm.io, this Option is a direct bridge to (*gorm.DB).Order()
// e.g.
//   CrudRepository.FindAll(ctx, &user, Page(2, 10), Sort("name DESC"))
//   CrudRepository.FindAllBy(ctx, &user, Where(...), Page(2, 10), Sort(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}))
func Sort(value interface{}) Option {
	orderScope := func(db *gorm.DB) *gorm.DB {
		return db.Order(value)
	}
	return gormOptions(orderScope)
}

// SortBy is an Option similar to Sort, but it takes the model's field name instead of a column.
// This Option also supports ordering by a direct "ToOne" relation's field when used together with Joins.
// e.g.
//   CrudRepository.FindAll(ctx, &user, Joins("Profile"), Page(2, 10), SortBy("Profile.FirstName", false))
//   CrudRepository.FindAllBy(ctx, &user, Where(...), Page(2, 10), SortBy("Username", true))
//
// If no schema can be resolved, ErrorUnsupportedOptions is added to the DB; an unknown field adds
// data.ErrorSortByUnknownColumn.
func SortBy(fieldName string, desc bool) Option {
	return gormOptions(func(db *gorm.DB) *gorm.DB {
		if err := requireSchema(db); err != nil {
			_ = db.AddError(ErrorUnsupportedOptions.WithMessage("SortBy not supported in this usage: %v", err))
			return db
		}
		column, err := toColumn(db.Statement.Schema, fieldName)
		if err != nil {
			_ = db.AddError(data.ErrorSortByUnknownColumn.WithMessage("SortBy error: %v", err))
			return db
		}
		orderBy := clause.OrderByColumn{Column: *column, Desc: desc}
		return db.Order(orderBy)
	})
}
// ErrorOnZeroRows is a post-exec option that forces the repository to return data.ErrorRecordNotFound when the
// operation succeeded but db.RowsAffected == 0.
// This option is useful on operations such as CrudRepository.Delete or CrudRepository.Update, which don't
// return an error when no row was affected/deleted. An existing error is left untouched.
func ErrorOnZeroRows() Option {
	// Implementation Note:
	// Alternative way (probably the proper way) to implement this is to add "after *" callback that reads
	// statement's settings and process result accordingly, and ErrorOnZeroRows() can be a regular gormOptions that put a flag
	// in statement's settings.
	// The callback approach above would allow our ErrorTranslator to intercept the set error. But for this particular
	// use case, it doesn't matter because we don't need to translate data.ErrorRecordNotFound error
	return postExecOptions(func(db *gorm.DB) *gorm.DB {
		if db.Error != nil || db.RowsAffected != 0 {
			return db
		}
		db.Error = data.ErrorRecordNotFound
		return db
	})
}
/***********************
Helpers
***********************/
// requireSchema ensures db.Statement.Schema is populated, parsing db.Statement.Model when only the model is set.
// It returns an error when neither a schema nor a model is available, or when parsing the model fails.
// The parse error is wrapped with %w so callers can still inspect it with errors.Is / errors.As.
func requireSchema(db *gorm.DB) error {
	stmt := db.Statement
	if stmt.Schema != nil {
		return nil
	}
	if stmt.Model == nil {
		return fmt.Errorf("schema/model is not available")
	}
	if e := stmt.Parse(stmt.Model); e != nil {
		return fmt.Errorf("failed to parse schema - %w", e)
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package repo
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"reflect"
"strings"
)
// GormSchemaResolver extends SchemaResolver to expose more schema related functions
type GormSchemaResolver interface {
	SchemaResolver
	// Schema returns raw schema.Schema.
	Schema() *schema.Schema
}

// GormMetadata implements GormSchemaResolver through the embedded gormSchemaResolver. It additionally keeps
// the repository's model and a table classifying the value types accepted for that model.
type GormMetadata struct {
	gormSchemaResolver
	// model is a pointer to a new zero value of the model's struct type (set by newGormMetadata)
	model interface{}
	// types maps the model struct type and its pointer/slice variants (plus map[string]interface{}) to a typeKey;
	// it is consulted by isSupportedValue
	types map[reflect.Type]typeKey
}
// newGormMetadata builds GormMetadata for the given model.
//
// model must be a struct value or a pointer (of any depth) to a struct. Anything else, including nil or a
// pointer to a slice/map/primitive, is rejected with ErrorInvalidCrudModel. The metadata caches the struct type
// together with its pointer and slice variants, and the gorm schema parsed from model.
func newGormMetadata(db *gorm.DB, model interface{}) (GormMetadata, error) {
	if model == nil {
		return GormMetadata{}, ErrorInvalidCrudModel.WithMessage("%T is not a valid model for gorm CRUD repository", model)
	}

	// dereference pointers until the underlying type is reached
	sType := reflect.TypeOf(model)
	for sType.Kind() == reflect.Ptr {
		sType = sType.Elem()
	}
	// Only struct models are valid. Without this check, a pointer to a non-struct (e.g. *[]Model) would be
	// accepted and produce a type table keyed on the wrong types.
	if sType.Kind() != reflect.Struct {
		return GormMetadata{}, ErrorInvalidCrudModel.WithMessage("%T is not a valid model for gorm CRUD repository", model)
	}

	// cache some types
	pType := reflect.PointerTo(sType)
	types := map[reflect.Type]typeKey{
		pType:                                      typeModelPtr,
		sType:                                      typeModel,
		reflect.PointerTo(reflect.SliceOf(sType)):  typeModelSlicePtr,
		reflect.PointerTo(reflect.SliceOf(pType)):  typeModelPtrSlicePtr,
		reflect.SliceOf(sType):                     typeModelSlice,
		reflect.SliceOf(pType):                     typeModelPtrSlice,
		reflect.TypeOf(map[string]interface{}{}):   typeGenericMap,
	}

	resolver, e := newGormSchemaResolver(db, model)
	if e != nil {
		return GormMetadata{}, e
	}
	return GormMetadata{
		gormSchemaResolver: resolver,
		model:              reflect.New(sType).Interface(),
		types:              types,
	}, nil
}
// isSupportedValue reports whether value's dynamic type is one of the cached model-related types and
// its classification is contained in the given set of accepted typeKeys.
func (g GormMetadata) isSupportedValue(value interface{}, types utils.Set) bool {
	if key, found := g.types[reflect.TypeOf(value)]; found {
		return types.Has(key)
	}
	return false
}
// gormSchemaResolver implements GormSchemaResolver on top of a parsed gorm schema.
type gormSchemaResolver struct {
	schema *schema.Schema
}

// newGormSchemaResolver parses model's schema with db.Statement.Parse and wraps the result.
// NOTE: the parsed schema is read back from db.Statement, i.e. parsing updates the given db's Statement.
// A parse failure is reported as ErrorInvalidCrudModel.
func newGormSchemaResolver(db *gorm.DB, model interface{}) (gormSchemaResolver, error) {
	stmt := db.Statement
	// pre-parse schema
	if err := stmt.Parse(model); err != nil {
		return gormSchemaResolver{}, ErrorInvalidCrudModel.WithMessage("failed to parse schema of [%T] - %v", model, err)
	}
	return gormSchemaResolver{schema: stmt.Schema}, nil
}
// ModelType returns the model's reflect.Type as recorded in the schema.
func (g gormSchemaResolver) ModelType() reflect.Type {
	s := g.schema
	return s.ModelType
}

// ModelName returns the schema's model name.
func (g gormSchemaResolver) ModelName() string {
	s := g.schema
	return s.Name
}

// Table returns the schema's table name.
func (g gormSchemaResolver) Table() string {
	s := g.schema
	return s.Table
}

// ColumnName returns the DB column of fieldName (a Go field name or column name, optionally dotted through
// relationships, e.g. "Profile.FirstName"), or "" when no such field exists.
func (g gormSchemaResolver) ColumnName(fieldName string) string {
	field, _ := lookupField(g.schema, fieldName)
	if field == nil {
		return ""
	}
	return field.DBName
}

// ColumnDataType returns the gorm data type of fieldName (same lookup rules as ColumnName), or "" when not found.
func (g gormSchemaResolver) ColumnDataType(fieldName string) string {
	field, _ := lookupField(g.schema, fieldName)
	if field == nil {
		return ""
	}
	return string(field.DataType)
}

// RelationshipSchema returns a resolver for the schema reached by following the dotted relationship path
// fieldName, or nil if the path cannot be followed.
func (g gormSchemaResolver) RelationshipSchema(fieldName string) SchemaResolver {
	return relationshipSchema(g.schema, fieldName)
}

// Schema returns the raw parsed schema.
func (g gormSchemaResolver) Schema() *schema.Schema {
	s := g.schema
	return s
}
/*************************
Helpers
*************************/
// relationshipSchema follows the dotted relationship path fieldName starting at s and wraps the target schema
// in a resolver. It returns an untyped nil SchemaResolver when the path cannot be followed.
func relationshipSchema(s *schema.Schema, fieldName string) SchemaResolver {
	target := followRelationships(s, strings.Split(fieldName, "."))
	if target == nil {
		return nil
	}
	return gormSchemaResolver{schema: target}
}
// followRelationships finds the schema reached by following the relationship field path fieldPaths from s.
// It returns nil if any segment is not a relationship or has no resolvable schema. With an empty path, s itself
// is returned.
func followRelationships(s *schema.Schema, fieldPaths []string) *schema.Schema {
	ret := s
	for _, fieldName := range fieldPaths {
		if ret == nil {
			return nil
		}
		relation, ok := ret.Relationships.Relations[fieldName]
		// FieldSchema is what we continue with, so it must be checked too. Checking only relation.Schema
		// (the owning side) lets a nil FieldSchema through, which then panics on the next segment.
		if !ok || relation == nil || relation.Schema == nil || relation.FieldSchema == nil {
			return nil
		}
		ret = relation.FieldSchema
	}
	return ret
}
// lookupField is similar to schema.Schema.LookUpField, but Go field names take priority over DB column names.
// Dotted names such as "OneToOneFieldName.FieldName" are resolved by first following the relationships in the
// prefix; those prefix segments are returned as paths (nil for an undotted name).
// It returns (nil, nil) when the relationship path or the field cannot be found.
func lookupField(s *schema.Schema, name string) (f *schema.Field, paths []string) {
	if dot := strings.LastIndex(name, "."); dot >= 0 {
		paths = strings.Split(name[:dot], ".")
		if s = followRelationships(s, paths); s == nil {
			return nil, nil
		}
		name = name[dot+1:]
	}
	if byName, ok := s.FieldsByName[name]; ok {
		return byName, paths
	}
	if byDBName, ok := s.FieldsByDBName[name]; ok {
		return byDBName, paths
	}
	return nil, nil
}
// toColumn converts a (possibly dotted) field name into a clause.Column.
// Undotted names are qualified with clause.CurrentTable; for dotted names the relationship path (e.g. "Profile"
// for "Profile.FirstName") is used as the table, presumably matching the alias used by Joins("Profile").
func toColumn(s *schema.Schema, name string) (*clause.Column, error) {
	field, paths := lookupField(s, name)
	if field == nil {
		return nil, fmt.Errorf("field with name [%s] is not found on model %s", name, s.Name)
	}
	col := &clause.Column{Table: clause.CurrentTable, Name: field.DBName}
	if len(paths) > 0 {
		col.Table = strings.Join(paths, ".")
	}
	return col, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package repo
import (
"context"
)
// defaultUtils is returned by Utils() when it is called without options; it is assigned in initialize.
var defaultUtils Utility

// GormUtils implements Utility interface
type GormUtils struct {
	api GormApi
	// model is the model passed to Model(), if any
	model interface{}
	// resolver is the schema resolver bound by Model(); when nil, getSchemaResolver parses the schema
	// from the value given to each call
	resolver GormSchemaResolver
}
// Utils returns defaultUtils when called without options. Otherwise it creates a new GormUtils from the global
// factory, forwarding options to GormFactory.NewGormApi. It panics when the global factory is not a *GormFactory
// (e.g. the module has not been initialized).
func Utils(options ...interface{}) Utility {
	if len(options) == 0 {
		return defaultUtils
	}
	factory, ok := globalFactory.(*GormFactory)
	if !ok {
		panic("global repo factory is not set, unable to create Utility")
	}
	return newGormUtils(factory, options...)
}

// newGormUtils creates a GormUtils backed by a GormApi built from the factory with the given options.
func newGormUtils(factory *GormFactory, options ...interface{}) *GormUtils {
	return &GormUtils{api: factory.NewGormApi(options...)}
}
// Model returns a Utility bound to the given model. The model's schema is parsed once here and reused by later
// calls such as CheckUniqueness.
//
// If the schema cannot be parsed, no resolver is bound. The returned Utility then resolves the schema from the
// value passed to each call, just like an unbound GormUtils.
func (g GormUtils) Model(model interface{}) Utility {
	ret := &GormUtils{
		api:   g.api,
		model: model,
	}
	// Only bind the resolver on success. Storing the zero gormSchemaResolver (nil schema) in the interface field
	// would make it non-nil, and getSchemaResolver would then return a resolver that panics on first use.
	if resolver, e := newGormSchemaResolver(g.api.DB(context.Background()), model); e == nil {
		ret.resolver = resolver
	}
	return ret
}
// ResolveSchema parses the schema of model using the DB associated with ctx.
func (g GormUtils) ResolveSchema(ctx context.Context, model interface{}) (SchemaResolver, error) {
	db := g.api.DB(ctx)
	return newGormSchemaResolver(db, model)
}

// CheckUniqueness checks whether v conflicts with existing records on unique keys (see gormCheckUniqueness).
// keys optionally overrides the keys to check.
func (g GormUtils) CheckUniqueness(ctx context.Context, v interface{}, keys ...interface{}) (dups map[string]interface{}, err error) {
	var resolver GormSchemaResolver
	if resolver, err = g.getSchemaResolver(ctx, v); err != nil {
		return nil, err
	}
	return gormCheckUniqueness(ctx, g.api, resolver, v, keys)
}

// getSchemaResolver returns the resolver bound by Model(), or parses one from v when none is bound.
func (g GormUtils) getSchemaResolver(ctx context.Context, v interface{}) (GormSchemaResolver, error) {
	if g.resolver == nil {
		return newGormSchemaResolver(g.api.DB(ctx), v)
	}
	return g.resolver, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package repo
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"go.uber.org/fx"
)
//var logger = log.New("DB.Repo")

// globalFactory is the repository Factory received by initialize; Utils uses it to create Utility instances.
var globalFactory Factory

// Module is the "DB Repo" bootstrap module. It registers newGormFactory and newGormApi as fx providers and
// invokes initialize.
var Module = &bootstrap.Module{
	Name: "DB Repo",
	Precedence: bootstrap.DatabasePrecedence,
	Options: []fx.Option{
		fx.Provide(newGormFactory),
		fx.Provide(newGormApi),
		fx.Invoke(initialize),
	},
}
// initialize is invoked by fx through Module. It stores the Factory globally and builds defaultUtils from it.
// NOTE(review): the unchecked type assertion panics at startup if the provided Factory is not a *GormFactory.
func initialize(factory Factory) {
	globalFactory = factory
	defaultUtils = newGormUtils(factory.(*GormFactory))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package repo
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/data"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"reflect"
"sort"
"strings"
)
// index describes one (possibly composite) key used by the uniqueness Utility.
// Each element is a schema.IndexOption whose field (accessed as f.Field / f.DBName) is one column of the key.
type index []*schema.IndexOption
// gormCheckUniqueness checks whether any existing record conflicts with v on a unique key.
//
// keys optionally overrides the keys to check. Each key is a field name (string) or a composite key ([]string).
// When keys is empty, all UNIQUE indexes of the model schema are used. v may be a model, a map, or a slice
// (or a pointer to one of these).
//
// It returns (nil, nil) when no conflict exists. On conflict it returns the duplicated values keyed by field name,
// together with data.ErrorDuplicateKey whose message lists the pairs in sorted order. ErrorInvalidUtilityUsage is
// returned when v has no non-zero values for any key.
func gormCheckUniqueness(ctx context.Context, g GormApi, resolver GormSchemaResolver, v interface{}, keys []interface{}) (dups map[string]interface{}, err error) {
	// index keys override
	var uniqueKeys []index
	if len(keys) == 0 {
		uniqueKeys = findUniqueIndexes(resolver.Schema())
	} else if uniqueKeys, err = resolveIndexes(resolver.Schema(), keys); err != nil {
		return nil, err
	}

	// where clause: OR of one equality expression per model and usable key
	models := toInterfaces(v)
	var where clause.Where
	switch exprs := uniquenessWhere(models, uniqueKeys); {
	case len(exprs) == 0:
		return nil, ErrorInvalidUtilityUsage.WithMessage("%s is not possible. No non-zero unique fields found in given model [%v]", "CheckUniqueness", v)
	case len(exprs) == 1:
		where = clause.Where{Exprs: []clause.Expression{exprs[0]}}
	default:
		where = clause.Where{Exprs: []clause.Expression{clause.Or(exprs...)}}
	}

	// fetch a conflicting record, if any
	existing := reflect.New(resolver.ModelType()).Interface()
	switch rs := g.DB(ctx).Where(where).Take(existing); {
	case errors.Is(rs.Error, gorm.ErrRecordNotFound):
		return nil, nil
	case rs.Error != nil:
		return nil, rs.Error
	}

	// find duplicates
	for _, m := range models {
		dups = findDuplicateFields(m, existing, uniqueKeys)
		if len(dups) != 0 {
			break
		}
	}
	pairs := make([]string, 0, len(dups))
	for k, val := range dups {
		pairs = append(pairs, fmt.Sprintf("%s=[%v]", k, val))
	}
	// map iteration order is random; sort so the error message is deterministic
	sort.Strings(pairs)
	return dups, data.ErrorDuplicateKey.WithMessage("entity with following properties already exists: %s", strings.Join(pairs, ", "))
}
/************************
Helpers
************************/
// findUniqueIndexes returns the indexes of class "UNIQUE" parsed from the schema, one index per declared unique
// index. Indexes without fields are skipped.
func findUniqueIndexes(s *schema.Schema) []index {
	all := s.ParseIndexes()
	ret := make([]index, 0, len(all))
	for _, idx := range all {
		if idx.Class != "UNIQUE" || len(idx.Fields) == 0 {
			continue
		}
		key := make(index, 0, len(idx.Fields))
		for i := range idx.Fields {
			key = append(key, &idx.Fields[i])
		}
		ret = append(ret, key)
	}
	return ret
}
// resolveIndexes converts user-supplied keys into indexes. Each key is either a single field name (string) or a
// composite key ([]string); any other type yields ErrorInvalidUtilityUsage.
func resolveIndexes(s *schema.Schema, keys []interface{}) ([]index, error) {
	ret := make([]index, 0, len(keys))
	for _, key := range keys {
		var names []string
		switch v := key.(type) {
		case string:
			names = []string{v}
		case []string:
			names = v
		default:
			return nil, ErrorInvalidUtilityUsage.WithMessage("Invalid key type %T", key)
		}
		idx, err := asIndex(s, names)
		if err != nil {
			return nil, err
		}
		ret = append(ret, idx)
	}
	return ret, nil
}

// asIndex resolves each field name against the schema and builds an index from them.
// Unknown fields and fields reached through relationships are rejected.
func asIndex(s *schema.Schema, names []string) (index, error) {
	idx := make(index, 0, len(names))
	for _, name := range names {
		field, paths := lookupField(s, name)
		if field == nil {
			return nil, fmt.Errorf("field with name [%s] is not found on model %s", name, s.Name)
		}
		if len(paths) != 0 {
			return nil, fmt.Errorf("associations are not supported in this utils")
		}
		idx = append(idx, &schema.IndexOption{Field: field})
	}
	return idx, nil
}
// uniquenessWhere collects the uniqueness expressions of every model, in model order.
func uniquenessWhere(models []interface{}, keys []index) (exprs []clause.Expression) {
	for i := range models {
		exprs = append(exprs, uniquenessExprs(reflect.ValueOf(models[i]), keys)...)
	}
	return exprs
}

// uniquenessExprs returns one equality expression per key for which the model has usable (non-zero) values.
// Keys are ordered by number of columns (fewest first). NOTE: this stably sorts the caller's keys slice in place.
func uniquenessExprs(modelV reflect.Value, keys []index) []clause.Expression {
	sort.SliceStable(keys, func(a, b int) bool { return len(keys[a]) < len(keys[b]) })
	target := reflect.Indirect(modelV)
	exprs := make([]clause.Expression, 0, len(keys))
	for _, key := range keys {
		expr, ok := compositeEqExpr(target, key)
		if !ok {
			continue
		}
		exprs = append(exprs, expr)
	}
	return exprs
}
// compositeEqExpr builds the equality expression for one (possibly composite) key of the given model value:
// a single clause.Eq for a one-column key, or a clause.And of per-column Eq expressions otherwise.
//
// It returns false if
//  1. any index values are zero values
//  2. any column value is not found
//  3. len(idx) == 0
//
// Values extracted from map-based models arrive as interface-kind reflect.Values. They are unwrapped first so
// that the zero check applies to the held value (e.g. "" or 0), exactly as it does for struct fields.
func compositeEqExpr(modelV reflect.Value, idx index) (clause.Expression, bool) {
	andExprs := make([]clause.Expression, len(idx))
	for i, f := range idx {
		v, ok := extractValue(modelV, f.Field)
		if !ok {
			return nil, false
		}
		if v.IsValid() && v.Kind() == reflect.Interface {
			// Elem of a nil interface is the invalid Value, which is rejected below
			v = v.Elem()
		}
		if !v.IsValid() || v.IsZero() {
			return nil, false
		}
		andExprs[i] = clause.Eq{
			Column: clause.Column{Name: f.DBName},
			Value:  v.Interface(),
		}
	}
	switch {
	case len(andExprs) == 0:
		return nil, false
	case len(andExprs) == 1:
		return andExprs[0], true
	default:
		return clause.And(andExprs...), true
	}
}
// findDuplicateFields compares left and right on every field of every key and returns, keyed by field name,
// the values that are deeply equal on both sides. Fields missing on either side are ignored.
func findDuplicateFields(left interface{}, right interface{}, keys []index) map[string]interface{} {
	lhs := reflect.Indirect(reflect.ValueOf(left))
	rhs := reflect.Indirect(reflect.ValueOf(right))
	dups := make(map[string]interface{})
	for _, idx := range keys {
		for _, opt := range idx {
			lVal, found := extractValue(lhs, opt.Field)
			if !found {
				continue
			}
			rVal, found := extractValue(rhs, opt.Field)
			if !found {
				continue
			}
			lv := lVal.Interface()
			if reflect.DeepEqual(lv, rVal.Interface()) {
				dups[opt.Name] = lv
			}
		}
	}
	return dups
}

// extractValue reads field f from a struct value (by struct index) or from a map with string keys (matching either
// the Go field name or the DB column name). It reports false for other kinds or when the map has no matching key.
func extractValue(modelV reflect.Value, f *schema.Field) (reflect.Value, bool) {
	if modelV.Kind() == reflect.Struct {
		return modelV.FieldByIndex(f.StructField.Index), true
	}
	if modelV.Kind() != reflect.Map {
		return reflect.Value{}, false
	}
	iter := modelV.MapRange()
	for iter.Next() {
		if key, ok := iter.Key().Interface().(string); ok && (key == f.Name || key == f.DBName) {
			return iter.Value(), true
		}
	}
	return reflect.Value{}, false
}
// toInterfaces normalizes v into a list of models:
//   - a slice (or pointer to a slice) yields its elements (non-nil, possibly empty)
//   - a struct or map (or pointer to one) yields a single-element list holding v itself
//   - anything else, including nil, yields nil
func toInterfaces(v interface{}) (ret []interface{}) {
	rv := reflect.Indirect(reflect.ValueOf(v))
	kind := rv.Kind()
	if kind == reflect.Struct || kind == reflect.Map {
		return []interface{}{v}
	}
	if kind != reflect.Slice {
		return nil
	}
	n := rv.Len()
	ret = make([]interface{}, 0, n)
	for i := 0; i < n; i++ {
		ret = append(ret, rv.Index(i).Interface())
	}
	return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tx
import (
"context"
"database/sql"
)
// TxFunc is the unit of work run inside a transaction; the ctx it receives carries the transaction.
//goland:noinspection GoNameStartsWithPackageName
type TxFunc func(ctx context.Context) error

// TxManager runs a TxFunc inside a transaction. The gorm-based implementation only honors opts[0].
//goland:noinspection GoNameStartsWithPackageName
type TxManager interface {
	Transaction(ctx context.Context, tx TxFunc, opts ...*sql.TxOptions) error
}

// ManualTxManager defines interfaces for manual transaction management.
// Begin returns a context carrying the new transaction, and that context must be passed to the other methods.
// In the default implementation, Rollback and Commit return the context that was given to Begin.
// If any method returns an error, the returned context should be discarded.
type ManualTxManager interface {
	Begin(ctx context.Context, opts ...*sql.TxOptions) (context.Context, error)
	Rollback(ctx context.Context) (context.Context, error)
	Commit(ctx context.Context) (context.Context, error)
	SavePoint(ctx context.Context, name string) (context.Context, error)
	RollbackTo(ctx context.Context, name string) (context.Context, error)
}

// TxContext is implemented by contexts that wrap a parent context (e.g. the ones created by Begin);
// Parent returns the wrapped context.
//goland:noinspection GoNameStartsWithPackageName
type TxContext interface {
	Parent() context.Context
}
// txBacktraceCtxKey is the key type used to backtrack to the context wrapped by a txContext.
type txBacktraceCtxKey struct{}

// ctxKeyBeginCtx is the context key under which txContext exposes the context it wraps.
var ctxKeyBeginCtx = txBacktraceCtxKey{}

// txContext helps ManualTxManager to backtrace the context used for ManualTxManager.Begin.
// It embeds the wrapped context and exposes it both through Parent() and through Value(ctxKeyBeginCtx).
type txContext struct {
	context.Context
}

// newGormTxContext wraps the given context.Context in a txContext.
// It does not inspect ctx and does not track any nesting level.
func newGormTxContext(ctx context.Context) txContext {
	return txContext{
		Context: ctx,
	}
}

// Value returns the wrapped context for ctxKeyBeginCtx and delegates every other key to the wrapped context.
func (c txContext) Value(key interface{}) interface{} {
	// txBacktraceCtxKey is an empty struct, so every value of that type equals ctxKeyBeginCtx
	if _, ok := key.(txBacktraceCtxKey); ok {
		return c.Context
	}
	return c.Context.Value(key)
}

// Parent implements TxContext by returning the wrapped context.
func (c txContext) Parent() context.Context {
	return c.Context
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tx
import (
"context"
"database/sql"
"errors"
"github.com/cisco-open/go-lanai/pkg/data"
"gorm.io/gorm"
)
var (
	// ErrExceededMaxRetries is returned by DefaultExecuter.ExecuteTx when a retryable error persists past the
	// configured number of retries.
	ErrExceededMaxRetries = errors.New("exceeded maximum number of retries")
)

// DefaultExecuter executes the default behaviour for the TxManager and ManualTxManager.
// It is possible to define a custom Executer and swap it with this DefaultExecuter.
// To do so, see tx/package.go and the TransactionExecuter in the fx.In of the provideGormTxManager
type DefaultExecuter struct {
	// maxRetries can be defined by using the FxTransactionExecuterOption and MaxRetries option
	maxRetries int
}

// NewDefaultExecuter applies the given options in order and returns a DefaultExecuter configured from them.
func NewDefaultExecuter(options ...TransactionExecuterOption) TransactionExecuter {
	cfg := TransactionExecuterOptions{}
	for _, apply := range options {
		apply(&cfg)
	}
	executer := &DefaultExecuter{}
	executer.maxRetries = cfg.MaxRetries
	return executer
}
// ExecuteTx runs txFunc in a transaction started from db, or from the *gorm.DB carried by ctx when ctx is a
// GormContext. txFunc receives a context wrapping the transaction's *gorm.DB.
//
// Failed attempts are retried while ErrIsRetryable reports true. With maxRetries > 0, at most maxRetries retries
// are made (maxRetries+1 attempts) before ErrExceededMaxRetries is returned. With maxRetries <= 0, retries are
// unbounded. Non-retryable errors are returned unchanged.
func (r *DefaultExecuter) ExecuteTx(ctx context.Context, db *gorm.DB, opt *sql.TxOptions, txFunc TxFunc) error {
	// if we're in a transaction, make sure to use that db instead
	if gc, ok := ctx.(GormContext); ok {
		db = gc.DB()
	}
	run := func(txDB *gorm.DB) error {
		return txFunc(NewGormTxContext(ctx, txDB)) //nolint:contextcheck // this is equivalent to context.WithXXX
	}
	for failures := 1; ; failures++ {
		err := db.Transaction(run, opt)
		switch {
		case err == nil:
			return nil
		case !ErrIsRetryable(err):
			return err
		case r.maxRetries > 0 && failures > r.maxRetries:
			return ErrExceededMaxRetries
		}
	}
}
// Begin starts a transaction from db, or from the *gorm.DB carried by ctx when ctx is a GormContext. It returns
// a context wrapping the new transaction. On failure, the original ctx is returned together with the error.
func (r *DefaultExecuter) Begin(ctx context.Context, db *gorm.DB, opts ...*sql.TxOptions) (context.Context, error) {
	source := db
	// if we're in a transaction, make sure to use that db instead
	if gc, ok := ctx.(GormContext); ok {
		source = gc.DB()
	}
	tx := source.Begin(opts...)
	if err := tx.Error; err != nil {
		return ctx, err
	}
	return NewGormTxContext(ctx, tx), nil
}
// Rollback rolls back the transaction carried by ctx and returns ctx's parent. For contexts created by Begin,
// that is the context given to Begin. An ErrorCodeInvalidTransaction error is returned when ctx is not a
// TxContext with a parent.
func (r *DefaultExecuter) Rollback(ctx context.Context) (context.Context, error) {
	if err := DoWithDB(ctx, func(db *gorm.DB) *gorm.DB { return db.Rollback() }); err != nil {
		return ctx, err
	}
	tc, ok := ctx.(TxContext)
	if !ok || tc.Parent() == nil {
		return ctx, data.NewDataError(data.ErrorCodeInvalidTransaction, ErrTmplSPFailure)
	}
	return tc.Parent(), nil
}

// Commit commits the transaction carried by ctx and returns ctx's parent, with the same validation as Rollback.
func (r *DefaultExecuter) Commit(ctx context.Context) (context.Context, error) {
	if err := DoWithDB(ctx, func(db *gorm.DB) *gorm.DB { return db.Commit() }); err != nil {
		return ctx, err
	}
	tc, ok := ctx.(TxContext)
	if !ok || tc.Parent() == nil {
		return ctx, data.NewDataError(data.ErrorCodeInvalidTransaction, ErrTmplSPFailure)
	}
	return tc.Parent(), nil
}

// SavePoint creates the named savepoint in the transaction carried by ctx and returns ctx unchanged.
func (r *DefaultExecuter) SavePoint(ctx context.Context, name string) (context.Context, error) {
	if err := DoWithDB(ctx, func(db *gorm.DB) *gorm.DB { return db.SavePoint(name) }); err != nil {
		return ctx, err
	}
	tc, ok := ctx.(TxContext)
	if !ok || tc.Parent() == nil {
		return ctx, data.NewDataError(data.ErrorCodeInvalidTransaction, ErrTmplSPFailure)
	}
	return ctx, nil
}

// RollbackTo rolls the transaction carried by ctx back to the named savepoint and returns ctx unchanged.
func (r *DefaultExecuter) RollbackTo(ctx context.Context, name string) (context.Context, error) {
	if err := DoWithDB(ctx, func(db *gorm.DB) *gorm.DB { return db.RollbackTo(name) }); err != nil {
		return ctx, err
	}
	tc, ok := ctx.(TxContext)
	if !ok || tc.Parent() == nil {
		return ctx, data.NewDataError(data.ErrorCodeInvalidTransaction, ErrTmplSPFailure)
	}
	return ctx, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tx
import (
"context"
"database/sql"
"errors"
"gorm.io/gorm"
)
const (
	// ErrTmplSPFailure is the message DefaultExecuter uses when the given context is not a TxContext with a parent.
	// It is used by Rollback and Commit as well as by SavePoint and RollbackTo.
	ErrTmplSPFailure = `SavePoint failed. did you pass along the context provided by Begin(...)?`
)

// GormTxManager is a TxManager backed by a *gorm.DB; WithDB returns a manager using another *gorm.DB.
type GormTxManager interface {
	TxManager
	WithDB(*gorm.DB) GormTxManager
}

// GormContext is implemented by contexts that carry a *gorm.DB (e.g. gormTxContext).
type GormContext interface {
	DB() *gorm.DB
}

var (
	// ctxKeyGorm is the context key under which gormTxContext exposes its *gorm.DB via Value.
	ctxKeyGorm = gormCtxKey{}
)

type gormCtxKey struct{}
// gormTxContext is a txContext that additionally carries the transaction's *gorm.DB.
type gormTxContext struct {
	txContext
	db *gorm.DB
}

// NewGormTxContext wraps the given Context and *gorm.DB in a gormTxContext.
func NewGormTxContext(ctx context.Context, db *gorm.DB) context.Context {
	var c gormTxContext
	c.txContext = newGormTxContext(ctx)
	c.db = db
	return c
}

// Value returns the carried *gorm.DB for ctxKeyGorm and delegates all other keys to the embedded txContext.
func (c gormTxContext) Value(key interface{}) interface{} {
	if k, ok := key.(gormCtxKey); !ok || k != ctxKeyGorm {
		return c.txContext.Value(key)
	}
	return c.db
}

// DB implements GormContext.
func (c gormTxContext) DB() *gorm.DB {
	db := c.db
	return db
}
// GormTxWithContext returns the transaction *gorm.DB carried by ctx, bound to ctx via WithContext.
// The DB is looked up first on ctx itself (GormContext) and then through ctx.Value(ctxKeyGorm), so contexts
// derived from a gormTxContext also work. It returns nil when ctx carries no usable *gorm.DB.
func GormTxWithContext(ctx context.Context) (tx *gorm.DB) {
	if c, ok := ctx.(GormContext); ok && c.DB() != nil {
		return c.DB().WithContext(ctx)
	}
	// A gormTxContext created with a nil DB stores a typed nil *gorm.DB. That passes the type assertion, so it
	// must be checked explicitly instead of calling WithContext on nil.
	if db, ok := ctx.Value(ctxKeyGorm).(*gorm.DB); ok && db != nil {
		return db.WithContext(ctx)
	}
	return nil
}
// TransactionExecuterOptions holds the settings consumed by NewDefaultExecuter.
type TransactionExecuterOptions struct {
	// MaxRetries limits retries of retryable transaction errors in DefaultExecuter; values <= 0 mean unlimited.
	MaxRetries int
}

// TransactionExecuterOption mutates TransactionExecuterOptions.
type TransactionExecuterOption func(options *TransactionExecuterOptions)

// MaxRetries returns a TransactionExecuterOption that sets TransactionExecuterOptions.MaxRetries.
// The second parameter is unused. It is kept only so existing callers keep compiling.
func MaxRetries(maxRetries int, _ int) TransactionExecuterOption {
	return func(options *TransactionExecuterOptions) {
		options.MaxRetries = maxRetries
	}
}
// TransactionExecuter performs the transaction work delegated by gormTxManager. ExecuteTx backs
// TxManager.Transaction, and the remaining methods back ManualTxManager. DefaultExecuter is used when no other
// executer is provided through fx.
type TransactionExecuter interface {
	ExecuteTx(context.Context, *gorm.DB, *sql.TxOptions, TxFunc) error
	Begin(ctx context.Context, db *gorm.DB, opts ...*sql.TxOptions) (context.Context, error)
	Rollback(ctx context.Context) (context.Context, error)
	Commit(ctx context.Context) (context.Context, error)
	SavePoint(ctx context.Context, name string) (context.Context, error)
	RollbackTo(ctx context.Context, name string) (context.Context, error)
}
// gormTxManager implements TxManager, ManualTxManager and GormTxManager by delegating every operation to its
// TransactionExecuter, using db as the root connection.
type gormTxManager struct {
	db         *gorm.DB
	txExecuter TransactionExecuter
}

func newGormTxManager(db *gorm.DB, executer TransactionExecuter) *gormTxManager {
	m := new(gormTxManager)
	m.db = db
	m.txExecuter = executer
	return m
}

// WithDB returns a new manager that uses db and shares this manager's executer.
func (m gormTxManager) WithDB(db *gorm.DB) GormTxManager {
	return newGormTxManager(db, m.txExecuter)
}

// Transaction runs tx through the executer. According to gorm's finisher_api.go, Begin() only uses opts[0],
// so only the first option (if any) is forwarded.
func (m gormTxManager) Transaction(ctx context.Context, tx TxFunc, opts ...*sql.TxOptions) error {
	var first *sql.TxOptions
	if len(opts) != 0 {
		first = opts[0]
	}
	return m.txExecuter.ExecuteTx(ctx, m.db, first, tx)
}

// Begin starts a manual transaction; the executer uses the DB from ctx when ctx already carries one.
func (m gormTxManager) Begin(ctx context.Context, opts ...*sql.TxOptions) (context.Context, error) {
	executer := m.txExecuter
	return executer.Begin(ctx, m.db, opts...)
}

// Rollback delegates to the executer.
func (m gormTxManager) Rollback(ctx context.Context) (context.Context, error) {
	executer := m.txExecuter
	return executer.Rollback(ctx)
}

// Commit delegates to the executer.
func (m gormTxManager) Commit(ctx context.Context) (context.Context, error) {
	executer := m.txExecuter
	return executer.Commit(ctx)
}

// SavePoint delegates to the executer.
func (m gormTxManager) SavePoint(ctx context.Context, name string) (context.Context, error) {
	executer := m.txExecuter
	return executer.SavePoint(ctx, name)
}

// RollbackTo delegates to the executer.
func (m gormTxManager) RollbackTo(ctx context.Context, name string) (context.Context, error) {
	executer := m.txExecuter
	return executer.RollbackTo(ctx, name)
}
// gormTxManagerAdapter bridges a plain TxManager to GormTxManager. WithDB is a no-op. Useful for testing.
type gormTxManagerAdapter struct {
	TxManager
}

// WithDB ignores the given *gorm.DB and returns the adapter itself.
func (a gormTxManagerAdapter) WithDB(*gorm.DB) GormTxManager {
	return GormTxManager(a)
}
// DoWithDB applies fn to the *gorm.DB carried by ctx and returns the resulting DB's Error.
// It is a no-op that returns nil when ctx is not a GormContext or carries a nil DB.
func DoWithDB(ctx context.Context, fn func(*gorm.DB) *gorm.DB) error {
	gc, ok := ctx.(GormContext)
	if !ok {
		return nil
	}
	db := gc.DB()
	if db == nil {
		return nil
	}
	return fn(db).Error
}
// The below code is taken from crdb/tx.go in the crdb package

// ErrIsRetryable reports whether err (or any error it wraps) carries the standard PG SQLSTATE 40001
// (serialization failure), which indicates the transaction can be retried.
func ErrIsRetryable(err error) bool {
	const serializationFailure = "40001"
	return errCode(err) == serializationFailure
}

// errCode extracts the SQLSTATE from the first error in err's chain that exposes one, or "" if none does.
func errCode(err error) string {
	var withState errWithSQLState
	if !errors.As(err, &withState) {
		return ""
	}
	return withState.SQLState()
}

// errWithSQLState is implemented by pgx (pgconn.PgError) and lib/pq
type errWithSQLState interface {
	SQLState() string
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tx
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"go.uber.org/fx"
"gorm.io/gorm"
)
//var logger = log.New("DB.Tx")
// Module is the bootstrap module that provides the default TxManager/GormTxManager
// (provideGormTxManager) and stores the named TxManager as the package-level one (setGlobalTxManager).
var Module = &bootstrap.Module{
Name: "DB Tx",
Precedence: bootstrap.DatabasePrecedence,
Options: []fx.Option{
fx.Provide(provideGormTxManager),
fx.Invoke(setGlobalTxManager),
},
}
const (
// FxTransactionExecuterOption is the fx value-group name for TransactionExecuterOption providers.
// NOTE: struct tags cannot reference constants, so txDI repeats this literal; keep them in sync.
FxTransactionExecuterOption = "TransactionExecuterOption"
)
// txDI collects the (all optional) dependencies of provideGormTxManager.
type txDI struct {
fx.In
// UnnamedTx, when present, overrides the default manager (see provideGormTxManager).
UnnamedTx TxManager `optional:"true"`
// DB is required only when no UnnamedTx override is present.
DB *gorm.DB `optional:"true"`
// Executer, when absent, is built by NewDefaultExecuter from Options.
Executer TransactionExecuter `optional:"true"`
Options []TransactionExecuterOption `group:"TransactionExecuterOption"`
}
// txManagerOut publishes the chosen manager as the named "tx/TxManager" and as an unnamed GormTxManager.
type txManagerOut struct {
fx.Out
Tx TxManager `name:"tx/TxManager"`
GormTx GormTxManager
}
// provideGormTxManager provides the named "tx/TxManager" and the GormTxManager.
//
// Due to a limitation of uber/fx, providers cannot be overridden, which hurts testing & mocking.
// The workaround: the default is always published under a name, and when an unnamed TxManager is
// present in the graph it takes precedence and the default initialization is skipped:
//   - if the override also implements GormTxManager, it is used for both outputs;
//   - otherwise it is wrapped in gormTxManagerAdapter (WithDB becomes a no-op). Avoid this path.
//
// Without an override a *gorm.DB is required (panics otherwise). When no TransactionExecuter is
// provided, one is created by NewDefaultExecuter from the grouped TransactionExecuterOption values.
func provideGormTxManager(di txDI) txManagerOut {
	if di.UnnamedTx != nil {
		if override, ok := di.UnnamedTx.(GormTxManager); ok {
			return txManagerOut{Tx: override, GormTx: override}
		}
		// we should avoid this path
		return txManagerOut{Tx: di.UnnamedTx, GormTx: gormTxManagerAdapter{TxManager: di.UnnamedTx}}
	}
	if di.DB == nil {
		panic("default GormTxManager requires a *gorm.DB")
	}
	executer := di.Executer
	if executer == nil {
		executer = NewDefaultExecuter(di.Options...)
	}
	m := newGormTxManager(di.DB, executer)
	return txManagerOut{
		Tx:     m,
		GormTx: m,
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tx
import (
"context"
"database/sql"
"github.com/cisco-open/go-lanai/pkg/data"
"go.uber.org/fx"
)
// txManager is the package-level TxManager used by Transaction, Begin, Commit, etc.; set by setGlobalTxManager.
var txManager TxManager
// globalDI injects the named "tx/TxManager" (as published by txManagerOut).
type globalDI struct {
fx.In
Tx TxManager `name:"tx/TxManager"`
}
// setGlobalTxManager is invoked by Module (fx.Invoke) and stores the injected TxManager in txManager.
// NOTE(review): plain assignment without synchronization; assumes it runs once during app startup.
func setGlobalTxManager(di globalDI) {
txManager = di.Tx
}
// Transaction starts a transaction as a block: if tx returns an error the transaction is rolled back,
// otherwise it is committed.
func Transaction(ctx context.Context, tx TxFunc, opts ...*sql.TxOptions) error {
	return mustGetTxManager().Transaction(ctx, tx, opts...)
}

// Begin starts a transaction. The returned context.Context should be used for any transactional operations.
// If an error is returned, the returned context.Context should be discarded.
// Panics if the TxManager is not initialized or does not support manual transaction management.
func Begin(ctx context.Context, opts ...*sql.TxOptions) (context.Context, error) {
	return mustGetManualTxManager().Begin(ctx, opts...)
}

// Rollback rolls back a transaction. The returned context.Context is the original context provided when Begin was called.
// If an error is returned, the returned context.Context should be discarded.
// Panics if the TxManager is not initialized or does not support manual transaction management.
func Rollback(ctx context.Context) (context.Context, error) {
	return mustGetManualTxManager().Rollback(ctx)
}

// Commit commits a transaction. The returned context.Context is the original context provided when Begin was called.
// If an error is returned, the returned context.Context should be discarded.
// Panics if the TxManager is not initialized or does not support manual transaction management.
func Commit(ctx context.Context) (context.Context, error) {
	return mustGetManualTxManager().Commit(ctx)
}

// SavePoint works with RollbackTo and has to be called within a transaction.
// The returned context.Context should be used for any transactional operations between the corresponding SavePoint and RollbackTo.
// If an error is returned, the returned context.Context should be discarded.
func SavePoint(ctx context.Context, name string) (context.Context, error) {
	return mustGetManualTxManager().SavePoint(ctx, name)
}

// RollbackTo works with SavePoint and has to be called within a transaction.
// The returned context.Context should be used for any transactional operations between the corresponding SavePoint and RollbackTo.
// If an error is returned, the returned context.Context should be discarded.
func RollbackTo(ctx context.Context, name string) (context.Context, error) {
	return mustGetManualTxManager().RollbackTo(ctx, name)
}

// mustGetTxManager returns the package-level TxManager, panicking with an internal data error when
// it has not been set yet (i.e. setGlobalTxManager has not run).
func mustGetTxManager() TxManager {
	if txManager == nil {
		panic(data.NewDataError(data.ErrorCodeInternal, "TxManager is not initialized yet. Too early to call tx functions"))
	}
	return txManager
}

// mustGetManualTxManager returns the package-level TxManager as a ManualTxManager. Instead of a bare
// runtime type-assertion panic, it panics with a descriptive internal data error when the configured
// TxManager (e.g. an unnamed override) does not support manual transaction management.
func mustGetManualTxManager() ManualTxManager {
	m, ok := mustGetTxManager().(ManualTxManager)
	if !ok {
		panic(data.NewDataError(data.ErrorCodeInternal, "TxManager does not support manual transaction management (Begin/Commit/Rollback/SavePoint/RollbackTo)"))
	}
	return m
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package types
import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
const (
// TagFilter is the struct tag key read by newBoolFilterClause for FilterBool fields, e.g. `filter:"-"`.
TagFilter = "filter"
)
// FixWhereClausesForStatementModifier applies a special fix for
// db.Model(&model{}).Where(&model{f1:v1}).Or(&model{f2:v2})...
// Ref: https://github.com/go-gorm/gorm/issues/3627
// https://github.com/go-gorm/gorm/commit/9b2181199d88ed6f74650d73fa9d20264dd134c0#diff-e3e9193af67f3a706b3fe042a9f121d3609721da110f6a585cdb1d1660fd5a3c
//
// When the statement's WHERE clause has more than one expression and at least one of them is an
// OrConditions holding exactly one expression, all existing expressions are wrapped into a single
// clause.And group. A condition a StatementModifier adds afterwards is then combined with the group
// as a whole. Otherwise the statement is left untouched.
//
// Important: utility function for go-lanai internal use
func FixWhereClausesForStatementModifier(stmt *gorm.Statement) {
	cl := stmt.Clauses["WHERE"]
	where, ok := cl.Expression.(clause.Where)
	if !ok || len(where.Exprs) <= 1 {
		return
	}
	for _, expr := range where.Exprs {
		if orCond, isOr := expr.(clause.OrConditions); isOr && len(orCond.Exprs) == 1 {
			where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
			cl.Expression = where
			stmt.Clauses["WHERE"] = cl
			return
		}
	}
}
// NoopStatementModifier is meant to be embedded in StatementModifier implementations.
// It provides no-op implementations of the clause.Interface methods (Name, Build, MergeClause).
type NoopStatementModifier struct{}

// Name implements clause.Interface; always returns an empty name.
func (sm NoopStatementModifier) Name() string {
	// noop
	return ""
}

// Build implements clause.Interface; writes nothing.
func (sm NoopStatementModifier) Build(clause.Builder) {
	// noop
}

// MergeClause implements clause.Interface; leaves the given clause unchanged.
func (sm NoopStatementModifier) MergeClause(*clause.Clause) {
	// noop
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package types
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"strconv"
"strings"
)
/****************************
Func
****************************/
// SkipBoolFilter is a gorm scope that skips filtering of FilterBool fields with the given field names,
// e.g. db.WithContext(ctx).Scopes(SkipBoolFilter("FieldName1", "FieldName2")).Find(...)
//
// To disable all FilterBool filtering, provide no field names or "*",
// e.g. db.WithContext(ctx).Scopes(SkipBoolFilter()).Find(...)
//
// The overrides are stored in the statement's context, so using this scope without a context panics.
func SkipBoolFilter(fieldNames ...string) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		ctx := tx.Statement.Context
		if ctx == nil {
			panic("SkipBoolFilter used without context")
		}
		keys := fieldNames
		if len(keys) == 0 {
			// no names means "all fields"
			keys = []string{"*"}
		}
		for _, name := range keys {
			ctx = context.WithValue(ctx, ckFilterMode(name), fmDisabled)
		}
		tx.Statement.Context = ctx
		return tx
	}
}
// BoolFiltering is a gorm scope that changes the default/tag behavior of FilterBool field filtering
// for the given field names. filterVal is the value to filter out:
//   - true:  models whose field is true are filtered out
//   - false: models whose field is false are filtered out
//
// e.g. db.WithContext(ctx).Scopes(BoolFiltering(false, "FieldName1", "FieldName2")).Find(...)
// would filter out any model with "FieldName1" or "FieldName2" equal to false.
//
// To override all FilterBool filtering, provide no field names or "*",
// e.g. db.WithContext(ctx).Scopes(BoolFiltering(false)).Find(...)
//
// The overrides are stored in the statement's context, so using this scope without a context panics.
func BoolFiltering(filterVal bool, fieldNames ...string) func(*gorm.DB) *gorm.DB {
	// the mode only depends on filterVal, so compute it once rather than on every scope invocation
	mode := fmNegative
	if filterVal {
		mode = fmPositive
	}
	return func(tx *gorm.DB) *gorm.DB {
		if tx.Statement.Context == nil {
			panic("BoolFiltering used without context")
		}
		ctx := tx.Statement.Context
		for _, fieldName := range fieldNames {
			ctx = context.WithValue(ctx, ckFilterMode(fieldName), mode)
		}
		if len(fieldNames) == 0 {
			ctx = context.WithValue(ctx, ckFilterMode("*"), mode)
		}
		tx.Statement.Context = ctx
		return tx
	}
}
/****************************
Types
****************************/
// ckFilterMode is the context key type for per-field filterMode overrides; the key "*" applies to all fields.
type ckFilterMode string

// filterMode enum of possible values fm*
type filterMode int

const (
	// fmPositive filters out rows whose value is true (the default).
	fmPositive filterMode = iota
	// fmNegative filters out rows whose value is false.
	fmNegative
	// fmDisabled disables filtering.
	fmDisabled
)
// FilterBool implements
//   - schema.GormDataTypeInterface
//   - schema.QueryClausesInterface
//
// this data type adds "WHERE" clause in SELECT statements for filtering out models based on this field's value
//
// FilterBool by default filter out true values (WHERE filter_bool_col IS NOT TRUE AND ....).
// this behavior can be changed to using tag `filter:"<-|true|false>"`
//   - `filter:"-"`: disables the filtering at model declaration level.
//     Can be enabled on per query basis using scopes or repo options (if applicable)
//   - `filter:"true"`: filter out "true" values, the default behavior
//     Can be overridden on per query basis using scopes or repo options (if applicable)
//   - `filter:"false"`: filter out "false" values
//     Can be overridden on per query basis using scopes or repo options (if applicable)
//
// See SkipBoolFilter and BoolFiltering for filtering behaviour overriding
type FilterBool bool

// Value implements driver.Valuer. The value is always written as a valid (non-NULL) bool.
func (t FilterBool) Value() (driver.Value, error) {
	nb := sql.NullBool{Valid: true}
	nb.Bool = bool(t)
	return nb.Value()
}

// Scan implements sql.Scanner. NULL is read as false; on error t is left unchanged.
func (t *FilterBool) Scan(src interface{}) error {
	var nb sql.NullBool
	if err := nb.Scan(src); err != nil {
		return err
	}
	*t = FilterBool(nb.Valid && nb.Bool)
	return nil
}

// GormDataType implements schema.GormDataTypeInterface.
func (t FilterBool) GormDataType() string {
	return "bool"
}
// QueryClauses implements schema.QueryClausesInterface: every query on a model with this field gets a
// boolFilterClause configured from the field's `filter` tag (see newBoolFilterClause).
func (t FilterBool) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{newBoolFilterClause(f)}
}
/****************************
Helpers
****************************/
// boolFilterClause implements clause.Interface (via the embedded NoopStatementModifier) and
// gorm.StatementModifier; ModifyStatement does the real work.
// See gorm.DeletedAt for impl. reference
type boolFilterClause struct {
NoopStatementModifier
// FilterMode is the mode derived from the field's `filter` tag; context overrides take precedence.
FilterMode filterMode
// Field is the FilterBool field being filtered; DBName is used for the column, Name for context overrides.
Field *schema.Field
}
// newBoolFilterClause creates the filtering clause for field f based on its `filter` tag
// (case-insensitive, surrounding whitespace ignored):
// "-" disables filtering, "false" filters out false values, anything else (including "true" or
// no tag) filters out true values.
func newBoolFilterClause(f *schema.Field) clause.Interface {
	c := &boolFilterClause{Field: f, FilterMode: fmPositive}
	switch strings.ToLower(strings.TrimSpace(f.Tag.Get(TagFilter))) {
	case "-":
		c.FilterMode = fmDisabled
	case "false":
		c.FilterMode = fmNegative
	}
	return c
}
// ModifyStatement implements gorm.StatementModifier. It adds the bool filtering condition to the
// statement's WHERE clause, unless filtering is disabled by tag or by a context override
// (see determineFilterMode).
//
// FilterBool.Scan reads NULL as false, so NULL is treated as false here too, which matches the
// documented "IS NOT TRUE" behavior:
//   - fmPositive (filter out true):  "<col> IS NOT true" keeps false and NULL rows
//   - fmNegative (filter out false): "<col> IS true" drops false and NULL rows
func (c boolFilterClause) ModifyStatement(stmt *gorm.Statement) {
	mode := c.determineFilterMode(stmt.Context)
	if mode == fmDisabled {
		return
	}

	// special fix for db.Model(&model{}).Where(&model{f1:v1}).Or(&model{f2:v2})...
	// Ref: https://github.com/go-gorm/gorm/issues/3627
	// https://github.com/go-gorm/gorm/commit/9b2181199d88ed6f74650d73fa9d20264dd134c0#diff-e3e9193af67f3a706b3fe042a9f121d3609721da110f6a585cdb1d1660fd5a3c
	FixWhereClausesForStatementModifier(stmt)

	// add bool filtering
	colExpr := stmt.Quote(clause.Column{Table: clause.CurrentTable, Name: c.Field.DBName})
	op := "IS"
	if mode == fmPositive {
		op = "IS NOT"
	}
	stmt.AddClause(clause.Where{Exprs: []clause.Expression{
		clause.Expr{
			SQL: fmt.Sprintf("%s %s %s", colExpr, op, strconv.FormatBool(true)),
		},
	}})
}
/***********************
Helpers
***********************/
// determineFilterMode resolves the effective filterMode. A context override for "*" wins over an
// override for this field's Name, which wins over the tag-derived c.FilterMode. A nil ctx yields
// c.FilterMode.
func (c boolFilterClause) determineFilterMode(ctx context.Context) filterMode {
	if ctx == nil {
		return c.FilterMode
	}
	global, hasGlobal := ctx.Value(ckFilterMode("*")).(filterMode)
	if hasGlobal {
		return global
	}
	perField, hasPerField := ctx.Value(ckFilterMode(c.Field.Name)).(filterMode)
	if hasPerField {
		return perField
	}
	return c.FilterMode
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqcrypt
import (
"encoding"
"encoding/json"
"fmt"
"github.com/google/uuid"
"regexp"
"strconv"
"strings"
)
const (
	// v1Separator separates the fields of the V1 text format "<ver>:<kid>:<alg>:<data>".
	v1Separator = ":"
)

var (
	// v1TextPrefix is the prefix every V1 text starts with (the V1 version number followed by v1Separator).
	v1TextPrefix = fmt.Sprintf("%d%s", V1, v1Separator)
	// javaTypePattern matches (as a prefix) a dotted, fully-qualified Java class name such as
	// "java.util.HashMap". The pattern is a constant, so MustCompile is safe and surfaces a broken
	// pattern at init instead of silently leaving a nil *regexp.Regexp.
	javaTypePattern = regexp.MustCompile(`^[[:alpha:]][[:alnum:]]*(\.[[:alpha:]][[:alnum:]]*)+`)
	// jsonNull is the JSON null literal.
	jsonNull = `null`
)
/*************************
Parsing
*************************/
// ParseEncryptedRaw parses text into an EncryptedRaw.
// The V1 (Java compatible) text format "<ver>:<kid>:<alg>:<data>" is tried first. Only when text is
// not V1 (ErrUnsupportedVersion) is it decoded as the JSON (V2) representation. Any other V1 parsing
// error is returned as-is.
func ParseEncryptedRaw(text string) (ret *EncryptedRaw, err error) {
	parsed := &EncryptedRaw{}
	v1Err := parsed.UnmarshalTextV1([]byte(text))
	if v1Err == nil {
		return parsed, nil
	}
	//nolint:errorlint // special error is global var
	if v1Err != ErrUnsupportedVersion {
		return nil, v1Err
	}
	if e := json.Unmarshal([]byte(text), parsed); e != nil {
		return nil, newInvalidFormatError("invalid V2 format - %v", e)
	}
	return parsed, nil
}
// UnmarshalTextV1 deserializes the V1 (Java compatible) text format "<ver>:<kid>:<alg>:<data>":
//   - text not starting with the V1 prefix yields ErrUnsupportedVersion
//   - kid must be a valid UUID
//   - alg must be a known Algorithm; for AlgPlain, data is taken as raw JSON, for AlgVault it is the
//     ciphertext and is stored as a quoted JSON string
//
// On success d is overwritten; on error d is left untouched.
func (d *EncryptedRaw) UnmarshalTextV1(text []byte) error {
	str := string(text)
	if !isV1Format(str) {
		return ErrUnsupportedVersion
	}
	// split into at most 4 parts: the data part may itself contain the separator
	split := strings.SplitN(str, v1Separator, 4)
	if len(split) < 4 {
		return newInvalidFormatError("not V1 format")
	}
	var ver Version
	if e := unmarshalText(split[0], &ver); e != nil {
		return newInvalidFormatError("unsupported version")
	}
	kid := split[1]
	if _, e := uuid.Parse(kid); e != nil {
		return newInvalidFormatError("invalid Key ID")
	}
	var alg Algorithm
	if e := unmarshalText(split[2], &alg); e != nil {
		return newInvalidFormatError("unsupported algorithm")
	}
	var raw json.RawMessage
	switch alg {
	case AlgPlain:
		raw = json.RawMessage(split[3])
	case AlgVault:
		raw = json.RawMessage(strconv.Quote(split[3]))
	}
	if !json.Valid(raw) {
		return newInvalidFormatError("unsupported raw data")
	}
	*d = EncryptedRaw{
		Ver:   ver,
		KeyID: kid,
		Alg:   alg,
		Raw:   raw,
	}
	return nil
}
/*************************
V1 Plain Data
*************************/
// v1DecryptedData holds the JSON payload extracted from a V1 (Java) unencrypted payload.
type v1DecryptedData json.RawMessage

// UnmarshalJSON implements json.Unmarshaler with V1 support.
// V1 (Java) format of an unencrypted payload could be
//   - a (T extends Map<String, String>) serialized by Jackson with `As.WRAPPER_ARRAY` option (JSON Array):
//     either [value] or ["fully.qualified.JavaType", value]; value is extracted
//   - a (T extends Map<String, String>) serialized by Jackson without `As.WRAPPER_ARRAY` option (JSON Object),
//     kept as-is
//
// Any other shape yields ErrInvalidV1Format.
func (d *v1DecryptedData) UnmarshalJSON(data []byte) (err error) {
	if len(data) == 0 {
		*d = nil
		return nil
	}
	switch data[0] {
	case '[':
		var s []json.RawMessage
		if e := json.Unmarshal(data, &s); e != nil {
			return e
		}
		// either [value], or [javaType, value] where javaType is a JSON string that looks like a Java type
		var v json.RawMessage
		switch len(s) {
		case 1:
			v = s[0]
		case 2:
			str := ""
			if e := json.Unmarshal(s[0], &str); e != nil || !javaTypePattern.Match([]byte(str)) {
				return ErrInvalidV1Format
			}
			v = s[1]
		default:
			return ErrInvalidV1Format
		}
		// elements of s are already copies (json.RawMessage copies on unmarshal)
		*d = v1DecryptedData(v)
	case '{':
		// json.Unmarshaler implementations must copy data if they retain it after returning
		*d = append(v1DecryptedData(nil), data...)
	default:
		return ErrInvalidV1Format
	}
	return nil
}
/*************************
helpers
*************************/
func isV1Format(text string) bool {
return strings.HasPrefix(text, v1TextPrefix)
}
// unmarshalText feeds the string data to v's UnmarshalText and returns its error.
func unmarshalText(data string, v encoding.TextUnmarshaler) error {
	raw := []byte(data)
	return v.UnmarshalText(raw)
}
// extractV1DecryptedPayload decodes the V1 (Java) format of an unencrypted payload and returns the
// embedded JSON value. V1 format could be
//   - a (T extends Map<String, String>) serialized by Jackson with `As.WRAPPER_ARRAY` option (JSON Array)
//   - a (T extends Map<String, String>) serialized by Jackson without `As.WRAPPER_ARRAY` option (JSON Object)
//   - a JSON "null"
//   - empty data
//
// Both nil/empty data and JSON null result in a JSON null.
func extractV1DecryptedPayload(data []byte) (json.RawMessage, error) {
	if len(data) == 0 || string(data) == jsonNull {
		return json.RawMessage(jsonNull), nil
	}
	decoded := v1DecryptedData{}
	err := json.Unmarshal(data, &decoded)
	if err != nil {
		return nil, newInvalidFormatError("unencrypted data JSON parsing error - %v", err)
	}
	return json.RawMessage(decoded), nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqcrypt
import (
"context"
"database/sql/driver"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/cisco-open/go-lanai/pkg/data/types/pqx"
"strings"
)
// Sentinel errors of this package. Code in this package compares them by identity (== / switch) rather
// than errors.Is, e.g. ParseEncryptedRaw and compositeEncryptor.Decrypt.
var (
// ErrUnsupportedVersion: the data version is not handled (e.g. text is not in V1 format, or a decryptor doesn't support Ver).
ErrUnsupportedVersion = data.NewDataError(data.ErrorCodeOrmMapping, "unsupported version of encrypted data format")
// ErrUnsupportedAlgorithm: an Encryptor does not handle the data's Alg; compositeEncryptor then tries the next one.
ErrUnsupportedAlgorithm = data.NewDataError(data.ErrorCodeOrmMapping, "unsupported encryption algorithm of data")
ErrInvalidFormat = data.NewDataError(data.ErrorCodeOrmMapping, "invalid encrypted data")
// ErrInvalidV1Format: a V1 unencrypted payload has an unexpected JSON shape.
ErrInvalidV1Format = data.NewDataError(data.ErrorCodeOrmMapping, "invalid V1 data payload format")
)
/*************************
Enums
*************************/
const (
	// V1 is Java compatible data structure
	V1 Version = 1
	// V2 is Generic JSON version, default format of go-lanai
	V2             Version = 2
	defaultVersion         = V2
	minVersion             = V2
	v1Text                 = "1"
	v2Text                 = "2"
)

// Version is the format version of EncryptedRaw.
type Version int

// UnmarshalText implements encoding.TextUnmarshaler. Accepts "1" or "2" (surrounding whitespace
// ignored); blank text yields defaultVersion. On error v is left unchanged.
func (v *Version) UnmarshalText(text []byte) error {
	s := strings.TrimSpace(string(text))
	var parsed Version
	switch s {
	case "":
		parsed = defaultVersion
	case v1Text:
		parsed = V1
	case v2Text:
		parsed = V2
	default:
		return fmt.Errorf("unknown encrypted data version: %s", s)
	}
	*v = parsed
	return nil
}

// UnmarshalJSON implements json.Unmarshaler with V1 support. Accepts the numbers 1 and 2; 0 (or JSON
// null) yields defaultVersion. On error v is left unchanged.
func (v *Version) UnmarshalJSON(data []byte) (err error) {
	var n int
	if err = json.Unmarshal(data, &n); err != nil {
		return err
	}
	if n == 0 {
		*v = defaultVersion
		return nil
	}
	if n != int(V1) && n != int(V2) {
		return fmt.Errorf("unknown encrypted data version: %d", n)
	}
	*v = Version(n)
	return nil
}
const (
	AlgPlain   Algorithm = "p"
	AlgVault   Algorithm = "e" // this value is compatible with Java counterpart
	defaultAlg           = AlgPlain
)

// Algorithm identifies how EncryptedRaw.Raw was produced.
type Algorithm string

// UnmarshalText implements encoding.TextUnmarshaler. Accepts "p" or "e" (surrounding whitespace
// ignored); blank text yields defaultAlg. On error a is left unchanged.
func (a *Algorithm) UnmarshalText(text []byte) error {
	s := strings.TrimSpace(string(text))
	if s == "" {
		*a = defaultAlg
		return nil
	}
	switch candidate := Algorithm(s); candidate {
	case AlgPlain, AlgVault:
		*a = candidate
		return nil
	}
	return fmt.Errorf("unknown encrypted data algorithm: %s", s)
}
/*************************
Data
*************************/
// EncryptedRaw is the carrier of encrypted data.
// This data type implements driver.Valuer, sql.Scanner and schema.GormDataTypeInterface, and is
// stored as jsonb. Field order determines the JSON key order when marshalled.
type EncryptedRaw struct {
// Ver is the format version ("v").
Ver Version `json:"v"`
// KeyID identifies the key used by the Encryptor ("kid").
KeyID string `json:"kid,omitempty"`
// Alg identifies the algorithm that produced Raw ("alg").
Alg Algorithm `json:"alg,omitempty"`
// Raw is the payload ("d"): plain JSON for AlgPlain, a JSON string holding the ciphertext for AlgVault.
Raw json.RawMessage `json:"d,omitempty"`
}
// GormDataType implements schema.GormDataTypeInterface
func (EncryptedRaw) GormDataType() string {
return "jsonb"
}
// Value implements driver.Valuer; a nil *EncryptedRaw yields NULL.
// The nil check must live here rather than in pqx.JsonbValue: once d is passed as interface{}, a typed
// nil pointer no longer compares equal to nil there.
func (d *EncryptedRaw) Value() (driver.Value, error) {
	if d != nil {
		return pqx.JsonbValue(d)
	}
	return nil, nil
}

// Scan implements sql.Scanner by delegating to pqx.JsonbScan.
func (d *EncryptedRaw) Scan(src interface{}) error {
	err := pqx.JsonbScan(src, d)
	return err
}
/*************************
Interface
*************************/
// Encryptor encrypts values into EncryptedRaw and decrypts them back.
type Encryptor interface {
// Encrypt encodes v with the key identified by kid and returns a new EncryptedRaw.
// The implementation determines Ver and Alg of the result.
Encrypt(ctx context.Context, kid string, v interface{}) (*EncryptedRaw, error)
// Decrypt reads raw and populates the decrypted data into dest.
// If dest is not a pointer type, this method may return an error.
// Implementations return ErrUnsupportedAlgorithm / ErrUnsupportedVersion for data they don't handle.
Decrypt(ctx context.Context, raw *EncryptedRaw, dest interface{}) error
// KeyOperations returns an object that operates on keys.
// Depending on configuration, this method may return a no-op impl, but never nil.
KeyOperations() KeyOperations
}
// KeyOptions customizes key creation. Reserved for future use; currently ignored.
type KeyOptions func(opt *keyOption)
type keyOption struct {
ktype string
exportable bool
allowPlaintextBk bool
}
// KeyOperations manages encryption keys.
type KeyOperations interface {
// Create creates keys with the given key ID.
// Note: KeyOptions is for future support, it's currently ignored
Create(ctx context.Context, kid string, opts ...KeyOptions) error
}
/*************************
Common
*************************/
// compositeEncryptor chains Encryptors: Encrypt uses the first one, Decrypt tries each in order.
type compositeEncryptor []Encryptor

// Encrypt always uses the first Encryptor; without any, it returns an encryption error.
func (enc compositeEncryptor) Encrypt(ctx context.Context, kid string, v interface{}) (*EncryptedRaw, error) {
	if len(enc) == 0 {
		return nil, newEncryptionError("encryptor is not properly configured")
	}
	return enc[0].Encrypt(ctx, kid, v)
}
// Decrypt tries each delegate in order. A delegate returning ErrUnsupportedAlgorithm or
// ErrUnsupportedVersion is skipped; the first success or any other error is returned. When no
// delegate handles raw, a decryption error describing raw's version and algorithm is returned.
func (enc compositeEncryptor) Decrypt(ctx context.Context, raw *EncryptedRaw, dest interface{}) error {
	if raw == nil {
		// checked up front: the fallback error below reads raw.Ver/raw.Alg and would panic on nil
		return newDecryptionError("raw data is nil")
	}
	for _, delegate := range enc {
		e := delegate.Decrypt(ctx, raw, dest)
		switch e { //nolint:errorlint // special error is global var
		case nil:
			return nil
		case ErrUnsupportedAlgorithm, ErrUnsupportedVersion:
			continue
		default:
			return e
		}
	}
	return newDecryptionError("encryptor is not available for ver=%d and alg=%v", raw.Ver, raw.Alg)
}
// KeyOperations collects the KeyOperations of all delegates, skipping the shared no-op one.
// The result may be empty but is never nil.
func (enc compositeEncryptor) KeyOperations() KeyOperations {
	collected := make(compositeKeyOperations, 0, len(enc))
	for _, delegate := range enc {
		if o := delegate.KeyOperations(); o != noopKeyOps {
			collected = append(collected, o)
		}
	}
	return collected
}
// compositeKeyOperations applies each operation to every delegate in order, stopping at the first error.
type compositeKeyOperations []KeyOperations

// Create creates the key with every delegate; the first error is returned.
func (o compositeKeyOperations) Create(ctx context.Context, kid string, opts ...KeyOptions) error {
	for i := range o {
		err := o[i].Create(ctx, kid, opts...)
		if err != nil {
			return err
		}
	}
	return nil
}

// noopKeyOperations is a KeyOperations whose operations do nothing and always succeed.
type noopKeyOperations struct{}

var noopKeyOps = noopKeyOperations{}

// Create does nothing and returns nil.
func (o noopKeyOperations) Create(context.Context, string, ...KeyOptions) error {
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqcrypt
import (
"context"
"encoding/json"
)
// plainTextEncryptor "encrypts" by storing the JSON encoding of the value as-is (AlgPlain).
type plainTextEncryptor struct{}

// Encrypt returns a V2/AlgPlain EncryptedRaw whose Raw is the JSON encoding of v. kid must not be empty.
func (enc plainTextEncryptor) Encrypt(_ context.Context, kid string, v interface{}) (raw *EncryptedRaw, err error) {
	if kid == "" {
		return nil, newEncryptionError("KeyID is required for algorithm %v", AlgPlain)
	}
	encoded, e := json.Marshal(v)
	if e != nil {
		return nil, newEncryptionError("cannot marshal data to JSON - %v", e)
	}
	return &EncryptedRaw{
		Ver:   V2,
		KeyID: kid,
		Alg:   AlgPlain,
		Raw:   encoded,
	}, nil
}
// Decrypt handles V1 and V2 data; other versions yield ErrUnsupportedVersion so that a
// compositeEncryptor may try the next delegate.
func (enc plainTextEncryptor) Decrypt(ctx context.Context, raw *EncryptedRaw, dest interface{}) error {
	if raw == nil {
		return newDecryptionError("raw data is nil")
	}
	if raw.Ver != V1 && raw.Ver != V2 {
		return ErrUnsupportedVersion
	}
	return enc.decrypt(ctx, raw, dest)
}

// KeyOperations returns the no-op KeyOperations: plain text needs no keys.
func (enc plainTextEncryptor) KeyOperations() KeyOperations {
	return noopKeyOps
}
// decrypt unmarshals the plain payload into dest. V1 payloads are first unwrapped via
// extractV1DecryptedPayload; unknown versions are a no-op. Non-AlgPlain data yields ErrUnsupportedAlgorithm.
func (enc plainTextEncryptor) decrypt(_ context.Context, raw *EncryptedRaw, dest interface{}) error {
	if raw.Alg != AlgPlain {
		return ErrUnsupportedAlgorithm
	}
	var payload json.RawMessage
	switch raw.Ver {
	case V1:
		extracted, e := extractV1DecryptedPayload(raw.Raw)
		if e != nil {
			return newDecryptionError("malformed V1 data - %v", e)
		}
		payload = extracted
	case V2:
		payload = raw.Raw
	default:
		return nil
	}
	if e := json.Unmarshal(payload, dest); e != nil {
		return newDecryptionError("failed to unmarshal decrypted data - %v", e)
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqcrypt
import (
"context"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/vault"
"strconv"
)
// vaultEncryptor implements Encryptor and KeyOperations using a vault.TransitEngine.
type vaultEncryptor struct {
transit vault.TransitEngine
props *KeyProperties
}
// newVaultEncryptor creates a vaultEncryptor whose transit engine creates keys with the type,
// exportability and plaintext-backup settings from props.
func newVaultEncryptor(client *vault.Client, props *KeyProperties) Encryptor {
return &vaultEncryptor{
transit: vault.NewTransitEngine(client, func(opt *vault.KeyOption) {
opt.KeyType = props.Type
opt.Exportable = props.Exportable
opt.AllowPlaintextBackup = props.AllowPlaintextBackup
}),
props: props,
}
}
// Encrypt JSON-encodes v, encrypts it with the transit engine under the normalized key ID and stores
// the ciphertext as a JSON string in Raw (V2/AlgVault). A nil v yields an EncryptedRaw with empty Raw.
func (enc *vaultEncryptor) Encrypt(ctx context.Context, kid string, v interface{}) (raw *EncryptedRaw, err error) {
	keyID := normalizeKeyID(kid)
	if keyID == "" {
		return nil, newEncryptionError("KeyID is required for algorithm %v", AlgVault)
	}
	result := &EncryptedRaw{Ver: V2, KeyID: keyID, Alg: AlgVault}
	if v == nil {
		// special rule encrypted []byte(nil) <-> nil
		return result, nil
	}
	plain, e := json.Marshal(v)
	if e != nil {
		return nil, newEncryptionError("failed to marshal data - %v", e)
	}
	cipher, e := enc.transit.Encrypt(ctx, keyID, plain)
	if e != nil {
		return nil, newEncryptionError("encryption engine - %v", e)
	}
	result.Raw = json.RawMessage(strconv.Quote(string(cipher)))
	return result, nil
}
// Decrypt handles AlgVault data of version V1 or V2. Other algorithms/versions yield
// ErrUnsupportedAlgorithm/ErrUnsupportedVersion; a nil raw or missing KeyID is a decryption error.
func (enc *vaultEncryptor) Decrypt(ctx context.Context, raw *EncryptedRaw, dest interface{}) error {
	if raw == nil {
		return newDecryptionError("raw data is nil")
	}
	if raw.Alg != AlgVault {
		return ErrUnsupportedAlgorithm
	}
	if raw.KeyID == "" {
		return newDecryptionError("KeyID is required for algorithm %v", raw.Alg)
	}
	if raw.Ver != V1 && raw.Ver != V2 {
		return ErrUnsupportedVersion
	}
	return enc.decrypt(ctx, raw, dest)
}

// KeyOperations returns enc itself, which implements KeyOperations.
func (enc *vaultEncryptor) KeyOperations() KeyOperations {
	return enc
}
/* KeyOperations */

// Create prepares the transit key for the normalized kid. KeyOptions are currently ignored.
func (enc *vaultEncryptor) Create(ctx context.Context, kid string, _ ...KeyOptions) error {
	normalized := normalizeKeyID(kid)
	if normalized != "" {
		return enc.transit.PrepareKey(ctx, normalized)
	}
	return fmt.Errorf("invalid key ID")
}
/* Helpers */

// decrypt decrypts the JSON-string ciphertext in raw.Raw with the transit engine and unmarshals the
// plaintext into dest (V1 payloads are unwrapped first; unknown versions are a no-op).
// Empty Raw is assigned as nil to dest.
func (enc *vaultEncryptor) decrypt(ctx context.Context, raw *EncryptedRaw, dest interface{}) error {
	if len(raw.Raw) == 0 {
		// special rule encrypted []byte(nil) <-> nil
		return tryAssign(nil, dest)
	}
	var ciphertext string
	if e := json.Unmarshal(raw.Raw, &ciphertext); e != nil {
		return newDecryptionError("invalid ciphertext - %v", e)
	}
	plain, e := enc.transit.Decrypt(ctx, normalizeKeyID(raw.KeyID), []byte(ciphertext))
	if e != nil {
		return newDecryptionError("encryption engine - %v", e)
	}
	var payload json.RawMessage
	switch raw.Ver {
	case V1:
		extracted, e := extractV1DecryptedPayload(plain)
		if e != nil {
			return newDecryptionError("malformed V1 data - %v", e)
		}
		payload = extracted
	case V2:
		payload = json.RawMessage(plain)
	default:
		return nil
	}
	if e := json.Unmarshal(payload, dest); e != nil {
		return newDecryptionError("failed to unmarshal decrypted data - %v", e)
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqcrypt
import (
"context"
"github.com/google/uuid"
)
const (
// errTmplNotConfigured is the error message used when no Encryptor is configured.
errTmplNotConfigured = `data encryption is not properly configured`
)
// encryptor is the shared Encryptor behind the package-level Encrypt/Decrypt/CreateKey APIs; defaults to plain text.
// NOTE(review): presumably replaced during module initialization — confirm with the setup code.
var encryptor Encryptor = plainTextEncryptor{}
// zeroUUID is the zero uuid.UUID; a zero key ID is treated as an empty key ID string.
var zeroUUID = uuid.UUID{}
// Encrypt is a package level API that wraps the shared Encryptor's Encrypt.
// Returns an encryption error when no Encryptor is configured.
func Encrypt(ctx context.Context, kid string, v interface{}) (*EncryptedRaw, error) {
	if enc := encryptor; enc != nil {
		return enc.Encrypt(ctx, kid, v)
	}
	return nil, newEncryptionError(errTmplNotConfigured)
}

// Decrypt is a package level API that wraps the shared Encryptor's Decrypt.
// Returns a decryption error when no Encryptor is configured.
func Decrypt(ctx context.Context, raw *EncryptedRaw, dest interface{}) error {
	if enc := encryptor; enc != nil {
		return enc.Decrypt(ctx, raw, dest)
	}
	return newDecryptionError(errTmplNotConfigured)
}

// CreateKey creates keys with the given key ID using the shared Encryptor's KeyOperations.
// Note: KeyOptions is for future support, it's currently ignored
func CreateKey(ctx context.Context, kid string, opts ...KeyOptions) error {
	if enc := encryptor; enc != nil {
		return enc.KeyOperations().Create(ctx, kid, opts...)
	}
	return newEncryptionError(errTmplNotConfigured)
}
// CreateKeyWithUUID creates keys with the given UUID key ID; a zero UUID is passed on as an empty key ID.
// Note: KeyOptions is for future support, it's currently ignored
func CreateKeyWithUUID(ctx context.Context, kid uuid.UUID, opts ...KeyOptions) error {
	var id string
	if kid != zeroUUID {
		id = kid.String()
	}
	return CreateKey(ctx, id, opts...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqcrypt
import (
"context"
"database/sql/driver"
"github.com/google/uuid"
)
// EncryptedMap is a map column value that is encrypted on write (Value) and decrypted on read (Scan).
// The embedded EncryptedRaw carries the key ID and, after Value/Scan, the encrypted representation;
// Data holds the plaintext map and is excluded from JSON.
type EncryptedMap struct {
	EncryptedRaw
	Data map[string]interface{} `json:"-"`
}
// NewEncryptedMap creates an EncryptedMap holding v, to be encrypted with the key identified by kid.
// A zero UUID results in an empty key ID.
func NewEncryptedMap(kid uuid.UUID, v map[string]interface{}) *EncryptedMap {
	var keyID string
	if kid != zeroUUID {
		keyID = kid.String()
	}
	return newEncryptedMap(keyID, v)
}
// newEncryptedMap builds an EncryptedMap with the given raw key ID and plaintext data.
func newEncryptedMap(kid string, v map[string]interface{}) *EncryptedMap {
	m := &EncryptedMap{Data: v}
	m.EncryptedRaw.KeyID = kid
	return m
}
// Value implements driver.Valuer.
//
// It encrypts Data with KeyID through the package-level Encrypt. The result is stored back into the
// embedded EncryptedRaw, so the receiver is mutated, and the EncryptedRaw's driver value is returned.
// A nil *EncryptedMap yields SQL NULL. database/sql only treats nil pointers as NULL for value-receiver
// Valuers, so without this check a nil field would panic here.
func (d *EncryptedMap) Value() (driver.Value, error) {
	if d == nil {
		return nil, nil
	}
	raw, e := Encrypt(context.Background(), d.KeyID, d.Data)
	if e != nil {
		return nil, e
	}
	d.EncryptedRaw = *raw
	return d.EncryptedRaw.Value()
}
// Scan implements sql.Scanner.
// It first scans src into the embedded EncryptedRaw, then decrypts that into Data.
func (d *EncryptedMap) Scan(src interface{}) error {
	e := d.EncryptedRaw.Scan(src)
	if e == nil {
		e = Decrypt(context.Background(), &d.EncryptedRaw, &d.Data)
	}
	return e
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqcrypt
import (
"embed"
"fmt"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/vault"
"go.uber.org/fx"
)
//var logger = log.New("Data.Enc")

// defaultConfigFS embeds defaults-data-enc.yml, which Module registers as embedded default configuration.
//
//go:embed defaults-data-enc.yml
var defaultConfigFS embed.FS

// Module is the bootstrap module for data encryption, registered with bootstrap.DatabasePrecedence.
// It registers the embedded defaults, provides DataEncryptionProperties and the Encryptor named
// "data/Encryptor", and invokes initialize to install that Encryptor for the package-level API.
var Module = &bootstrap.Module{
	Name:       "data-encryption",
	Precedence: bootstrap.DatabasePrecedence,
	Options: []fx.Option{
		appconfig.FxEmbeddedDefaults(defaultConfigFS),
		fx.Provide(BindDataEncryptionProperties, provideEncryptor),
		fx.Invoke(initialize),
	},
}

// Use registers Module via bootstrap.Register.
func Use() {
	bootstrap.Register(Module)
}
/**************************
Provider
***************************/
// encDI holds the optional dependencies consumed by provideEncryptor.
type encDI struct {
	fx.In
	// Properties is the bound data encryption configuration.
	Properties DataEncryptionProperties `optional:"true"`
	// Client is the vault client; provideEncryptor panics if it is missing while encryption is enabled.
	Client *vault.Client `optional:"true"`
	// UnnamedEnc is an Encryptor provided by the application; when present it takes precedence.
	UnnamedEnc Encryptor `optional:"true"`
}

// encOut publishes the selected Encryptor under the name "data/Encryptor".
type encOut struct {
	fx.Out
	Enc Encryptor `name:"data/Encryptor"`
}
// provideEncryptor selects the Encryptor published as "data/Encryptor":
//  1. an unnamed Encryptor already provided by the application, if any;
//  2. otherwise, when encryption is enabled, a vault encryptor composed with a plaintext encryptor
//     (panics when no vault client is available);
//  3. otherwise a plaintext encryptor.
func provideEncryptor(di encDI) encOut {
	if di.UnnamedEnc != nil {
		return encOut{Enc: di.UnnamedEnc}
	}
	if !di.Properties.Enabled {
		return encOut{Enc: plainTextEncryptor{}}
	}
	if di.Client == nil {
		panic(fmt.Errorf("data encryption enabled but vault client is not initialized"))
	}
	vaultEnc := newVaultEncryptor(di.Client, &di.Properties.Key)
	return encOut{Enc: compositeEncryptor{vaultEnc, plainTextEncryptor{}}}
}
/**************************
Initialize
***************************/
// initDI receives the Encryptor named "data/Encryptor".
type initDI struct {
	fx.In
	Enc Encryptor `name:"data/Encryptor"`
}

// initialize installs the injected Encryptor as the shared encryptor used by the package-level
// Encrypt, Decrypt and CreateKey functions.
func initialize(di initDI) {
	encryptor = di.Enc
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqcrypt
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"strings"
)
// PropertiesPrefix is the configuration prefix bound onto DataEncryptionProperties.
const (
	PropertiesPrefix = "data.encryption"
)

// DataEncryptionProperties is the data encryption configuration.
type DataEncryptionProperties struct {
	// Enabled selects the vault-backed encryptor; when false a plaintext encryptor is provided.
	Enabled bool `json:"enabled"`
	// Key is handed to newVaultEncryptor when encryption is enabled.
	Key KeyProperties `json:"key"`
}

// KeyProperties are the encryption key settings.
type KeyProperties struct {
	// Type is the key type name; NewDataEncryptionProperties defaults it to defaultKeyType.
	Type                 string `json:"type"`
	Exportable           bool   `json:"exportable"`
	AllowPlaintextBackup bool   `json:"allow-plaintext-backup"`
}
// Key type names accepted by Vault's transit engine when creating keys.
// https://www.vaultproject.io/api/secret/transit#create-key
const (
	KeyTypeAES128   = "aes128-gcm96"
	KeyTypeAES256   = "aes256-gcm96"
	KeyTypeChaCha20 = "chacha20-poly1305"
	KeyTypeED25519  = "ed25519"
	KeyTypeECDSA256 = "ecdsa-p256"
	KeyTypeECDSA384 = "ecdsa-p384"
	KeyTypeECDSA521 = "ecdsa-p521"
	KeyTypeRSA2048  = "rsa-2048"
	KeyTypeRSA3072  = "rsa-3072"
	KeyTypeRSA4096  = "rsa-4096"
	// defaultKeyType is used when no key type is configured.
	defaultKeyType = KeyTypeAES256
)

// supportedKeyTypes is the set of key type names accepted by KeyType.UnmarshalText.
var supportedKeyTypes = utils.NewStringSet(
	KeyTypeAES128, KeyTypeAES256, KeyTypeChaCha20,
	KeyTypeED25519, KeyTypeECDSA256, KeyTypeECDSA384, KeyTypeECDSA521,
	KeyTypeRSA2048, KeyTypeRSA3072, KeyTypeRSA4096,
)
// KeyType is an encryption key type name restricted to supportedKeyTypes when unmarshalled from text.
type KeyType string

// UnmarshalText implements encoding.TextUnmarshaler.
// The input is trimmed and lower-cased. Empty input selects defaultKeyType, and names outside
// supportedKeyTypes are rejected.
func (t *KeyType) UnmarshalText(text []byte) error {
	name := strings.TrimSpace(string(text))
	name = strings.ToLower(name)
	if name == "" {
		*t = defaultKeyType
		return nil
	}
	if !supportedKeyTypes.Has(name) {
		return fmt.Errorf("unknown encryption key type: %s", name)
	}
	*t = KeyType(name)
	return nil
}
// NewDataEncryptionProperties creates a DataEncryptionProperties with default values: encryption disabled,
// key type defaultKeyType, keys not exportable and no plaintext backup.
func NewDataEncryptionProperties() *DataEncryptionProperties {
	props := DataEncryptionProperties{}
	props.Key.Type = defaultKeyType
	return &props
}
// BindDataEncryptionProperties creates DataEncryptionProperties with defaults, then binds the application
// configuration under PropertiesPrefix onto it. It panics if binding fails.
func BindDataEncryptionProperties(ctx *bootstrap.ApplicationContext) DataEncryptionProperties {
	props := NewDataEncryptionProperties()
	err := ctx.Config().Bind(props, PropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind DataEncryptionProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqcrypt
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/data"
"reflect"
"strings"
)
// newInvalidFormatError returns an ORM-mapping data error for malformed encrypted data.
// text is a fmt format string applied to args.
func newInvalidFormatError(text string, args ...interface{}) error {
	detail := fmt.Sprintf(text, args...)
	return data.NewDataError(data.ErrorCodeOrmMapping, "invalid encrypted data: "+detail)
}

// newEncryptionError returns an ORM-mapping data error for a failed encryption.
// text is a fmt format string applied to args.
func newEncryptionError(text string, args ...interface{}) error {
	detail := fmt.Sprintf(text, args...)
	return data.NewDataError(data.ErrorCodeOrmMapping, "failed to encrypt data: "+detail)
}

// newDecryptionError returns an ORM-mapping data error for a failed decryption.
// text is a fmt format string applied to args.
func newDecryptionError(text string, args ...interface{}) error {
	detail := fmt.Sprintf(text, args...)
	return data.NewDataError(data.ErrorCodeOrmMapping, "failed to decrypt data: "+detail)
}

// normalizeKeyID returns kid lower-cased.
func normalizeKeyID(kid string) string {
	lowered := strings.ToLower(kid)
	return lowered
}
// tryAssign stores the decrypted value v into the variable that dest points to.
//
// dest must be a non-nil pointer to a settable value. A nil v resets the destination to its zero value.
// Otherwise v's type must be assignable to the destination type, or a decryption error is returned.
// Panics raised by reflection are recovered and reported as decryption errors.
func tryAssign(v interface{}, dest interface{}) (err error) {
	defer func() {
		if e := recover(); e != nil {
			err = newDecryptionError("recovered: %v", e)
		}
	}()

	// check
	rDest := reflect.ValueOf(dest)
	if rDest.Kind() != reflect.Ptr {
		return newDecryptionError("%T is not assignable", dest)
	}
	rDest = rDest.Elem()
	if !rDest.CanSet() {
		return newDecryptionError("%T is not assignable", dest)
	}

	// assign
	if v == nil {
		rDest.Set(reflect.New(rDest.Type()).Elem())
		return nil
	}
	rv := reflect.ValueOf(v)
	if !rv.Type().AssignableTo(rDest.Type()) {
		// Report the destination's static type. %T of rDest.Interface() would print the dynamic type
		// (or "<nil>") when the destination is an interface.
		return newDecryptionError("decrypted data type mismatch, expect %v, but got %T", rDest.Type(), v)
	}
	rDest.Set(rv)
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqx
import (
"database/sql/driver"
"fmt"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
// Duration is a defined type over utils.Duration (convertible to time.Duration) that can be stored in SQL columns.
type Duration utils.Duration

// Value implements driver.Valuer, encoding the duration in time.Duration's String format (e.g. "1h2m3s").
func (d Duration) Value() (driver.Value, error) {
	text := time.Duration(d).String()
	return text, nil
}
// Scan implements sql.Scanner.
//
// Text sources ([]byte or string) are parsed with utils.ParseDuration. Signed integer sources are taken as
// a raw duration count (nanoseconds, as with time.Duration). A nil src leaves d unchanged. Any other source
// type yields an ORM-mapping data error.
func (d *Duration) Scan(src interface{}) error {
	switch v := src.(type) {
	case []byte:
		*d = Duration(utils.ParseDuration(string(v)))
	case string:
		*d = Duration(utils.ParseDuration(v))
	// TODO review how convert numbers to Duration
	// Each integer type has its own case. In a multi-type case v would stay interface{}, and asserting it
	// to int64 panics for int, int8, int16 and int32.
	case int:
		*d = Duration(v)
	case int8:
		*d = Duration(v)
	case int16:
		*d = Duration(v)
	case int32:
		*d = Duration(v)
	case int64:
		*d = Duration(v)
	case nil:
		return nil
	default:
		return data.NewDataError(data.ErrorCodeOrmMapping,
			fmt.Sprintf("pqx: unable to convert data type %T to Duration", src))
	}
	return nil
}
// MarshalText implements encoding.TextMarshaler by delegating to utils.Duration.
func (d Duration) MarshalText() (text []byte, err error) {
	return utils.Duration(d).MarshalText()
}

// UnmarshalText implements encoding.TextUnmarshaler by delegating to utils.Duration.
func (d *Duration) UnmarshalText(text []byte) error {
	return (*utils.Duration)(d).UnmarshalText(text)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqx
import (
"database/sql/driver"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/data"
)
// JsonbScan helps models implement sql.Scanner by JSON-decoding src into v.
// src may be a string or []byte; a nil src is a no-op that leaves v untouched.
// Other source types, and invalid JSON, produce ORM-mapping data errors.
func JsonbScan(src interface{}, v interface{}) error {
	var raw []byte
	switch s := src.(type) {
	case nil:
		return nil
	case []byte:
		raw = s
	case string:
		raw = []byte(s)
	default:
		msg := fmt.Sprintf("unable to scan %T as JSONB format", src)
		return data.NewDataError(data.ErrorCodeOrmMapping, msg)
	}
	e := json.Unmarshal(raw, v)
	if e == nil {
		return nil
	}
	return data.NewDataError(data.ErrorCodeOrmMapping, fmt.Sprintf("unable to scan JSONB into %T: %v", v, e), e)
}
// JsonbValue helps models implement driver.Valuer by JSON-encoding v into a string.
// Only an untyped nil v yields SQL NULL. A nil map or pointer wrapped in an interface is encoded by
// encoding/json (as "null").
func JsonbValue(v interface{}) (driver.Value, error) {
	if v == nil {
		return nil, nil
	}
	encoded, e := json.Marshal(v)
	if e == nil {
		return string(encoded), nil
	}
	return nil, data.NewDataError(data.ErrorCodeInvalidSQL, fmt.Sprintf("unable to convert %T to JSONB: %v", v, e), e)
}
// JsonbMap is a generic JSON object stored in a JSONB column.
type JsonbMap map[string]interface{}

// Value implements driver.Valuer via JsonbValue.
// Note: a nil JsonbMap reaches JsonbValue as a non-nil interface, so it is JSON-encoded rather than stored as NULL.
func (m JsonbMap) Value() (driver.Value, error) {
	return JsonbValue(m)
}

// Scan implements sql.Scanner via JsonbScan.
func (m *JsonbMap) Scan(src interface{}) error {
	return JsonbScan(src, m)
}
// JsonbStringMap is a JSON object with string values stored in a JSONB column.
type JsonbStringMap map[string]string

// Value implements driver.Valuer via JsonbValue.
// Note: a nil JsonbStringMap reaches JsonbValue as a non-nil interface, so it is JSON-encoded rather than stored as NULL.
func (m JsonbStringMap) Value() (driver.Value, error) {
	return JsonbValue(m)
}

// Scan implements sql.Scanner via JsonbScan.
func (m *JsonbStringMap) Scan(src interface{}) error {
	return JsonbScan(src, m)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqx
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/tenancy"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/reflectutils"
"github.com/google/uuid"
"gorm.io/gorm"
"reflect"
)
// Go field names and DB column names of the tenancy fields.
const (
	fieldTenantID   = "TenantID"
	fieldTenantPath = "TenantPath"
	colTenantID     = "tenant_id"
	colTenantPath   = "tenant_path"
)

var (
	// reflect types used to locate tenancy fields/values in gorm targets
	typeUUID        = reflect.TypeOf(uuid.Nil)
	typeTenantPath  = reflect.TypeOf(TenantPath{})
	typeTenancy     = reflect.TypeOf(Tenancy{})
	typeTenancyPtr  = reflect.TypeOf(&Tenancy{})
	// map keys (field name or column name) recognized in map update targets
	mapKeysTenantID   = utils.NewStringSet(fieldTenantID, colTenantID)
	mapKeysTenantPath = utils.NewStringSet(fieldTenantPath, colTenantPath)
)
// ckTenancyCheckMode is the context key under which TenancyCheck stores a tcMode override.
type ckTenancyCheckMode struct{}

const (
	// TenancyCheckFlagWriteValueCheck checks that the current user can access the TenantID being written.
	TenancyCheckFlagWriteValueCheck TenancyCheckFlag = 1 << iota
	// TenancyCheckFlagWriteFiltering adds tenancy filtering to update/delete statements.
	TenancyCheckFlagWriteFiltering
	// TenancyCheckFlagReadFiltering adds tenancy filtering to query statements.
	TenancyCheckFlagReadFiltering
)

// TenancyCheckFlag bitwise Flag of tenancy flag mode
type TenancyCheckFlag uint

const (
	// tcModeDefault enables the write value check and write filtering, but not read filtering.
	tcModeDefault = tcMode(TenancyCheckFlagWriteFiltering | TenancyCheckFlagWriteValueCheck)
)

// tcMode enum of tenancyCheckMode: a bit set of TenancyCheckFlag values.
type tcMode uint
// hasFlags reports whether m overlaps every given flag. For single-bit flags this means all of them are set.
// With no flags it returns true.
func (m tcMode) hasFlags(flags ...TenancyCheckFlag) bool {
	for i := range flags {
		if tcMode(flags[i])&m == 0 {
			return false
		}
	}
	return true
}
// SkipTenancyCheck is used as a scope for gorm.DB to skip tenancy check
// e.g. db.WithContext(ctx).Scopes(SkipTenancyCheck()).Find(...)
// It is TenancyCheck with no flags, i.e. every check/filter flag is cleared for this statement.
// Note using this scope without context would panic
func SkipTenancyCheck() func(*gorm.DB) *gorm.DB {
	return TenancyCheck(0)
}
// TenancyCheck is used as a scope for gorm.DB to override tenancy check.
// The given flags are OR-ed into a mode stored in the statement context, which replaces the model's
// default/tag-derived mode. Flags that are not given are disabled.
// e.g. db.WithContext(ctx).Scopes(TenancyCheck(TenancyCheckFlagReadFiltering)).Find(...)
// Note using this scope without context would panic
func TenancyCheck(flags ...TenancyCheckFlag) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		stmtCtx := tx.Statement.Context
		if stmtCtx == nil {
			panic("SkipTenancyCheck used without context")
		}
		mode := tcMode(0)
		for _, f := range flags {
			mode |= tcMode(f)
		}
		tx.Statement.Context = context.WithValue(stmtCtx, ckTenancyCheckMode{}, mode)
		return tx
	}
}
// Tenancy is an embedded type for data model. It's responsible for populating TenantPath and checking tenancy
// related data when creating/updating. Tenancy implements
//   - callbacks.BeforeCreateInterface
//   - callbacks.BeforeUpdateInterface
//
// When used as an embedded type, tag `filter` can be used to override default tenancy check behavior:
//   - `filter:"w"`: create/update/delete are enforced (Default mode)
//   - `filter:"rw"`: CRUD operations are all enforced,
//     this mode filters result of any Select/Update/Delete query based on current security context
//   - `filter:"-"`: filtering is disabled. Note: setting TenantID to in-accessible tenant is still enforced.
//     to disable TenantID value check, use SkipTenancyCheck
//
// e.g.
//
//	type TenancyModel struct {
//		ID      uuid.UUID `gorm:"primaryKey;type:uuid;default:gen_random_uuid();"`
//		Tenancy `filter:"rw"`
//	}
type Tenancy struct {
	// TenantID is the owning tenant, stored in a uuid column (the element type of TenantPath's uuid[]).
	TenantID uuid.UUID `gorm:"type:uuid;not null"`
	// TenantPath is derived from TenantID in BeforeCreate/BeforeUpdate and hidden from JSON.
	TenantPath TenantPath `gorm:"type:uuid[];index:,type:gin;not null" json:"-"`
}
// SkipTenancyCheck is used for embedding models to override tenancy check behavior.
// It applies the package-level SkipTenancyCheck scope to tx, clearing all tenancy check flags.
// It should be called within model's hooks. this function would panic if context is not set yet
func (Tenancy) SkipTenancyCheck(tx *gorm.DB) {
	SkipTenancyCheck()(tx)
}
// BeforeCreate implements callbacks.BeforeCreateInterface.
//
// It requires a non-nil TenantID. It then checks that the current user can access that tenant, unless
// shouldSkip reports the write value check as disabled (nil/unauthenticated context, or a TenancyCheck
// override). Finally it populates TenantPath from the tenancy hierarchy; a lookup error is returned as-is.
func (t *Tenancy) BeforeCreate(tx *gorm.DB) error {
	//if tenantId is not available
	if t.TenantID == uuid.Nil {
		return errors.New("tenantId is required")
	}
	tenantID := t.TenantID.String()
	if !shouldSkip(tx.Statement.Context, TenancyCheckFlagWriteValueCheck, tcModeDefault) && !security.HasAccessToTenant(tx.Statement.Context, tenantID) {
		return fmt.Errorf("user does not have access to tenant %s", tenantID)
	}
	path, err := tenancy.GetTenancyPath(tx.Statement.Context, tenantID)
	if err == nil {
		t.TenantPath = path
	}
	return err
}
// BeforeUpdate Check if user is allowed to update this item's tenancy to the target tenant.
// (i.e. if user has access to the target tenant)
// We don't check the original tenancy because we don't have that information in this hook. That check has to be done
// in application code.
//
// The target TenantID is read from the gorm update destination. If none is set, nothing is checked. Otherwise
// access is verified (unless skipped), and the matching TenantPath is written into the update destination.
func (t *Tenancy) BeforeUpdate(tx *gorm.DB) error {
	dest := tx.Statement.Dest
	tenantId, e := t.extractTenantId(tx.Statement.Context, dest)
	if e != nil || tenantId == uuid.Nil {
		return e
	}
	tenantID := tenantId.String()
	if !shouldSkip(tx.Statement.Context, TenancyCheckFlagWriteValueCheck, tcModeDefault) && !security.HasAccessToTenant(tx.Statement.Context, tenantID) {
		return fmt.Errorf("user does not have access to tenant %s", tenantID)
	}
	path, err := tenancy.GetTenancyPath(tx.Statement.Context, tenantID)
	if err == nil {
		err = t.updateTenantPath(tx.Statement.Context, dest, path)
	}
	return err
}
// extractTenantId finds the TenantID value in a gorm update target.
//
// dest may be a string-keyed map (key "TenantID" or "tenant_id") or a struct, optionally behind pointers.
// It returns uuid.Nil and no error when no non-zero tenant ID is present. Errors are returned for
// unsupported target types, and when the located value is not a uuid.UUID. That can happen e.g. with a
// map[string]interface{} carrying a string; previously it panicked inside the gorm hook.
func (t Tenancy) extractTenantId(_ context.Context, dest interface{}) (uuid.UUID, error) {
	v := reflect.ValueOf(dest)
	for ; v.Kind() == reflect.Ptr; v = v.Elem() {
		// SuppressWarnings go:S108 empty block is intended
	}
	switch v.Kind() {
	case reflect.Map:
		if v.Type().Key().Kind() != reflect.String {
			return uuid.Nil, fmt.Errorf("unsupported gorm update target type [%T], please use struct ptr, struct or map", dest)
		}
		if k, ev, ok := t.findMapValue(v, mapKeysTenantID, typeUUID); ok {
			if id, isUUID := ev.Interface().(uuid.UUID); isUUID {
				return id, nil
			}
			return uuid.Nil, fmt.Errorf("invalid %s value type [%T] in gorm update target, expected uuid.UUID", k, ev.Interface())
		}
	case reflect.Struct:
		if _, fv, ok := t.findStructField(v, fieldTenantID, typeUUID); ok {
			if id, isUUID := fv.Interface().(uuid.UUID); isUUID {
				return id, nil
			}
			return uuid.Nil, fmt.Errorf("invalid %s value type [%T] in gorm update target, expected uuid.UUID", fieldTenantID, fv.Interface())
		}
	default:
		return uuid.Nil, fmt.Errorf("unsupported gorm update target type [%T], please use struct ptr, struct or map", dest)
	}
	return uuid.Nil, nil
}
// updateTenantPath writes tenancyPath into a gorm update target.
//
// Struct pointers get their TenantPath field set (when present). String-keyed maps get a "TenantPath" entry
// added when no non-zero TenantPath/tenant_path entry exists. An existing entry must equal tenancyPath,
// otherwise an error is returned. Structs passed by value and other kinds are rejected.
func (t *Tenancy) updateTenantPath(_ context.Context, dest interface{}, tenancyPath TenantPath) error {
	rv := reflect.ValueOf(dest)
	// a struct passed by value cannot be modified in place
	if rv.Kind() == reflect.Struct {
		return fmt.Errorf("cannot update tenancy automatically to %T, please use struct ptr or map", dest)
	}
	for rv.Kind() == reflect.Ptr {
		rv = rv.Elem()
	}

	switch rv.Kind() {
	case reflect.Struct:
		if _, field, found := t.findStructField(rv, fieldTenantPath, typeTenantPath); found {
			field.Set(reflect.ValueOf(tenancyPath))
		}
		return nil
	case reflect.Map:
		if rv.Type().Key().Kind() != reflect.String {
			return fmt.Errorf("cannot update tenancy automatically with gorm update target type [%T], please use struct ptr or map", dest)
		}
		key, existing, found := t.findMapValue(rv, mapKeysTenantPath, typeTenantPath)
		if !found {
			rv.SetMapIndex(reflect.ValueOf(fieldTenantPath), reflect.ValueOf(tenancyPath))
			return nil
		}
		// an explicitly set tenant path is kept as long as it is correct
		if !reflect.DeepEqual(existing.Interface(), tenancyPath) {
			return fmt.Errorf("incorrect %s was set to gorm update target map", key)
		}
		return nil
	default:
		return errors.New("cannot update tenancy automatically, please use struct ptr or map as gorm update target value")
	}
}
// findStructField uses reflectutils.FindStructField on sv's type to locate a field named name whose type can
// accept values of type ft. When found, it also returns that field's value within sv.
func (Tenancy) findStructField(sv reflect.Value, name string, ft reflect.Type) (f reflect.StructField, fv reflect.Value, ok bool) {
	matches := func(sf reflect.StructField) bool {
		return sf.Name == name && ft.AssignableTo(sf.Type)
	}
	if f, ok = reflectutils.FindStructField(sv.Type(), matches); !ok {
		return f, fv, ok
	}
	fv = sv.FieldByIndex(f.Index)
	return f, fv, true
}
// findMapValue scans the string-keyed map mv for an entry whose key is in keys, whose value is non-zero,
// and whose value type can accept type ft. It returns that key and value, or ok=false when there is none.
// Map iteration order is random, so which entry wins when several match is unspecified.
func (Tenancy) findMapValue(mv reflect.Value, keys utils.StringSet, ft reflect.Type) (string, reflect.Value, bool) {
	iter := mv.MapRange()
	for iter.Next() {
		key := iter.Key().String()
		if !keys.Has(key) {
			continue
		}
		val := iter.Value()
		if val.IsZero() || !ft.AssignableTo(val.Type()) {
			continue
		}
		return key, val, true
	}
	return "", reflect.Value{}, false
}
// shouldSkip reports whether the tenancy check identified by flag should be skipped.
// Checks are always skipped for a nil context or a user that is not fully authenticated. Otherwise the mode
// stored in the context by TenancyCheck is used when present, else fallback.
func shouldSkip(ctx context.Context, flag TenancyCheckFlag, fallback tcMode) bool {
	if ctx == nil || !security.IsFullyAuthenticated(security.Get(ctx)) {
		return true
	}
	mode := fallback
	if override, ok := ctx.Value(ckTenancyCheckMode{}).(tcMode); ok {
		mode = override
	}
	return !mode.hasFlags(flag)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqx
import (
"database/sql/driver"
"fmt"
securityinternal "github.com/cisco-open/go-lanai/internal/security"
"github.com/cisco-open/go-lanai/pkg/data/types"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils/reflectutils"
"github.com/google/uuid"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"reflect"
"strings"
)
/****************************
Func
****************************/
/****************************
Types
****************************/
// TenantPath implements
// - schema.GormDataTypeInterface
// - schema.QueryClausesInterface
// - schema.UpdateClausesInterface
// - schema.DeleteClausesInterface
// this data type adds "WHERE" clause for tenancy filtering
// Its storage format is that of UUIDArray.
type TenantPath UUIDArray

// Value implements driver.Valuer by delegating to UUIDArray.
func (t TenantPath) Value() (driver.Value, error) {
	return UUIDArray(t).Value()
}

// Scan implements sql.Scanner by delegating to UUIDArray.
func (t *TenantPath) Scan(src interface{}) error {
	return (*UUIDArray)(t).Scan(src)
}

// GormDataType implements schema.GormDataTypeInterface; the column type is uuid[].
func (t TenantPath) GormDataType() string {
	return "uuid[]"
}

// QueryClauses implements schema.QueryClausesInterface,
// adding a filter clause governed by the read filtering flag.
func (t TenantPath) QueryClauses(f *schema.Field) []clause.Interface {
	return []clause.Interface{newTenancyFilterClause(f, true)}
}

// UpdateClauses implements schema.UpdateClausesInterface,
// adding a filter clause governed by the write filtering flag.
func (t TenantPath) UpdateClauses(f *schema.Field) []clause.Interface {
	return []clause.Interface{newTenancyFilterClause(f, false)}
}

// DeleteClauses implements schema.DeleteClausesInterface,
// adding a filter clause governed by the write filtering flag.
func (t TenantPath) DeleteClauses(f *schema.Field) []clause.Interface {
	return []clause.Interface{newTenancyFilterClause(f, false)}
}
// tenancyFilterClause implements clause.Interface and gorm.StatementModifier, where gorm.StatementModifier do the real work.
// See gorm.DeletedAt for impl. reference
type tenancyFilterClause struct {
	types.NoopStatementModifier
	// Flag is the check this clause enforces (read filtering for queries, write filtering otherwise).
	Flag TenancyCheckFlag
	// Mode is the fallback mode derived from the model's `filter` tag, used unless overridden via context.
	Mode tcMode
	// Field is the TenantPath field whose column is filtered.
	Field *schema.Field
}
// newTenancyFilterClause builds the tenancy filter clause for the TenantPath field f.
//
// The fallback mode comes from the `filter` tag (see extractTenancyFilterTag):
//   - no tag: tcModeDefault (write value check + write filtering)
//   - "-":    write value check only
//   - other:  write value check, plus read filtering if the tag contains 'r'
//     and write filtering if it contains 'w'
//
// isRead selects the enforced flag: read filtering for queries, write filtering otherwise.
func newTenancyFilterClause(f *schema.Field, isRead bool) *tenancyFilterClause {
	tag := extractTenancyFilterTag(f)
	mode := tcModeDefault
	if tag != "" {
		mode = tcMode(TenancyCheckFlagWriteValueCheck)
		if tag != "-" {
			if strings.ContainsRune(tag, 'r') {
				mode |= tcMode(TenancyCheckFlagReadFiltering)
			}
			if strings.ContainsRune(tag, 'w') {
				mode |= tcMode(TenancyCheckFlagWriteFiltering)
			}
		}
	}

	enforced := TenancyCheckFlagWriteFiltering
	if isRead {
		enforced = TenancyCheckFlagReadFiltering
	}
	return &tenancyFilterClause{Flag: enforced, Mode: mode, Field: f}
}
// extractTenancyFilterTag returns the trimmed, lower-cased `filter` tag governing the TenantPath field f.
// The tag on the field itself wins. Otherwise the tag on the model's anonymous (embedded) Tenancy or
// *Tenancy field is used. No tag yields "".
func extractTenancyFilterTag(f *schema.Field) string {
	if tag, ok := f.Tag.Lookup(types.TagFilter); ok {
		return strings.ToLower(strings.TrimSpace(tag))
	}
	// check if tag is available on embedded Tenancy
	sf, ok := reflectutils.FindStructField(f.Schema.ModelType, func(t reflect.StructField) bool {
		return t.Anonymous && (t.Type.AssignableTo(typeTenancy) || t.Type.AssignableTo(typeTenancyPtr))
	})
	if ok {
		// Normalize exactly like a tag on the field itself. Otherwise e.g. `filter:"RW"` on the embedded
		// Tenancy would match neither 'r' nor 'w' and silently disable filtering.
		return strings.ToLower(strings.TrimSpace(sf.Tag.Get(types.TagFilter)))
	}
	return ""
}
// ModifyStatement implements gorm.StatementModifier.
// Unless the check is skipped, it adds a WHERE condition requiring the tenant path column to contain one of
// the current user's tenant IDs, using "<col> @> ?" per tenant with multiple conditions OR-ed.
// NOTE(review): when requiredTenancyFiltering yields no IDs (wildcard access, no tenant details, or details
// whose IDs all fail to parse), no condition is added at all. Confirm this fail-open behavior is intended.
func (c tenancyFilterClause) ModifyStatement(stmt *gorm.Statement) {
	if shouldSkip(stmt.Context, c.Flag, c.Mode) {
		return
	}
	tenantIDs := requiredTenancyFiltering(stmt)
	if len(tenantIDs) == 0 {
		return
	}

	// special fix for db.Model(&model{}).Where(&model{f1:v1}).Or(&model{f2:v2})...
	// Ref: https://github.com/go-gorm/gorm/issues/3627
	//      https://github.com/go-gorm/gorm/commit/9b2181199d88ed6f74650d73fa9d20264dd134c0#diff-e3e9193af67f3a706b3fe042a9f121d3609721da110f6a585cdb1d1660fd5a3c
	types.FixWhereClausesForStatementModifier(stmt)

	// one containment condition per accessible tenant
	column := stmt.Quote(clause.Column{Table: clause.CurrentTable, Name: c.Field.DBName})
	containsSQL := fmt.Sprintf("%s @> ?", column)
	conditions := make([]clause.Expression, len(tenantIDs))
	for i, id := range tenantIDs {
		conditions[i] = clause.Expr{SQL: containsSQL, Vars: []interface{}{UUIDArray{id}}}
	}

	where := clause.Where{Exprs: conditions}
	if len(conditions) > 1 {
		where = clause.Where{Exprs: []clause.Expression{clause.Or(conditions...)}}
	}
	stmt.AddClause(where)
}
/***********************
Helpers
***********************/
// requiredTenancyFiltering returns the tenant IDs the current user may access, for use as filter values.
// It returns nil when the authentication details do not provide tenant access, or when the user has
// wildcard tenant access. IDs that are not valid UUIDs are ignored, so the result may also be an empty slice.
func requiredTenancyFiltering(stmt *gorm.Statement) (tenantIDs []uuid.UUID) {
	details, _ := security.Get(stmt.Context).Details().(securityinternal.TenantAccessDetails)
	if details == nil {
		return nil
	}
	assigned := details.EffectiveAssignedTenantIds()
	if assigned.Has(security.SpecialTenantIdWildcard) {
		return nil
	}
	tenantIDs = make([]uuid.UUID, 0, len(assigned))
	for idStr := range assigned {
		id, e := uuid.Parse(idStr)
		if e != nil {
			continue
		}
		tenantIDs = append(tenantIDs, id)
	}
	return tenantIDs
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqx
import (
"bytes"
"database/sql/driver"
"fmt"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/lib/pq"
"time"
)
// TimeArray is a slice of time.Time stored as an array literal; it implements driver.Valuer and sql.Scanner.
type TimeArray []time.Time

// Value implements driver.Valuer.
// A nil array is NULL and an empty array is "{}". Otherwise each element is formatted with
// pq.FormatTimestamp, quoted, comma-separated and wrapped in braces.
func (a TimeArray) Value() (driver.Value, error) {
	switch {
	case a == nil:
		return nil, nil
	case len(a) == 0:
		return "{}", nil
	}
	// minimum size: 2 braces + 2*N quotes + (N-1) commas
	buf := make([]byte, 0, 1+3*len(a))
	buf = append(buf, '{')
	for i, ts := range a {
		if i > 0 {
			buf = append(buf, ',')
		}
		buf = appendArrayQuotedBytes(buf, pq.FormatTimestamp(ts))
	}
	buf = append(buf, '}')
	return string(buf), nil
}
// Scan implements sql.Scanner.
// []byte and string sources are parsed as an array literal; nil resets a to nil.
func (a *TimeArray) Scan(src interface{}) error {
	var raw []byte
	switch v := src.(type) {
	case nil:
		*a = nil
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return fmt.Errorf("pqx: cannot convert %T to TimeArray", src)
	}
	return a.scanBytes(raw)
}
// scanBytes parses src as a text array via pq.StringArray, then parses each element with pq.ParseTimestamp
// (UTC as the current location). a is only replaced when every element parses.
func (a *TimeArray) scanBytes(src []byte) error {
	var parts pq.StringArray
	if e := parts.Scan(src); e != nil {
		return data.NewDataError(data.ErrorCodeOrmMapping, e)
	}
	times := make(TimeArray, 0, len(parts))
	for idx, part := range parts {
		ts, e := pq.ParseTimestamp(time.UTC, part)
		if e != nil {
			return data.NewDataError(data.ErrorCodeOrmMapping,
				fmt.Sprintf("pqx: parsing array at idx %d: %v", idx, e.Error()), e)
		}
		times = append(times, ts)
	}
	*a = times
	return nil
}
// appendArrayQuotedBytes appends v to b as a double-quoted array element, backslash-escaping each '"' and
// '\' in v, and returns the extended slice.
func appendArrayQuotedBytes(b, v []byte) []byte {
	b = append(b, '"')
	for len(v) > 0 {
		idx := bytes.IndexAny(v, `"\`)
		if idx < 0 {
			b = append(b, v...)
			break
		}
		b = append(b, v[:idx]...)
		b = append(b, '\\', v[idx])
		v = v[idx+1:]
	}
	b = append(b, '"')
	return b
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package pqx
import (
"database/sql/driver"
"fmt"
"github.com/google/uuid"
"github.com/lib/pq"
)
// UUIDArray is a slice of uuid.UUID stored as a text array via pq.StringArray.
type UUIDArray []uuid.UUID

// Value implements driver.Valuer. A nil array is NULL; otherwise the UUIDs' string forms are encoded by
// pq.StringArray.
func (a UUIDArray) Value() (driver.Value, error) {
	if a != nil {
		return pq.StringArray(a.Strings()).Value()
	}
	return nil, nil
}
// Scan implements sql.Scanner.
// src is decoded with pq.StringArray and each element parsed as a UUID. a is only replaced when every
// element parses. Scanning into a nil *UUIDArray is a no-op.
func (a *UUIDArray) Scan(src interface{}) error {
	if a == nil {
		return nil
	}
	var strs pq.StringArray
	if e := strs.Scan(src); e != nil {
		return e
	}
	parsed := make(UUIDArray, 0, len(strs))
	for _, s := range strs {
		id, e := uuid.Parse(s)
		if e != nil {
			return fmt.Errorf("pq: cannot convert %T to UUIDArray - %v", src, e)
		}
		parsed = append(parsed, id)
	}
	*a = parsed
	return nil
}
// Strings returns the string form of each UUID in a, in order.
func (a UUIDArray) Strings() []string {
	out := make([]string, len(a))
	for i := range a {
		out[i] = a[i].String()
	}
	return out
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consulsd
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/consul"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/log"
"sync"
)
// ClientOptions customizes the ClientConfig of a discovery client created by NewDiscoveryClient.
type ClientOptions func(opt *ClientConfig)

// ClientConfig holds settings that the discovery client passes on to every Instancer it creates.
type ClientConfig struct {
	Logger          log.ContextualLogger
	Verbose         bool
	DefaultSelector discovery.InstanceMatcher
}

// consulDiscoveryClient is a Consul-backed discovery.Client that caches one Instancer per service name.
type consulDiscoveryClient struct {
	ctx        context.Context
	conn       *consul.Connection
	instancers map[string]*Instancer // guarded by mutex
	mutex      sync.Mutex
	config     ClientConfig
}
// NewDiscoveryClient creates a Consul-backed discovery.Client bound to ctx and conn.
// The config defaults to the package logger with verbose output off, then opts are applied in order.
// It panics if ctx is nil.
func NewDiscoveryClient(ctx context.Context, conn *consul.Connection, opts ...ClientOptions) discovery.Client {
	if ctx == nil {
		panic("creating ConsulDiscoveryClient with nil context")
	}
	c := &consulDiscoveryClient{
		ctx:        ctx,
		conn:       conn,
		instancers: make(map[string]*Instancer),
		config:     ClientConfig{Logger: logger, Verbose: false},
	}
	for _, apply := range opts {
		apply(&c.config)
	}
	return c
}
// Context returns the context the client was created with.
func (c *consulDiscoveryClient) Context() context.Context {
	return c.ctx
}
// Instancer returns the cached Instancer for serviceName, creating and starting one via NewInstancer on first
// use with the client's connection, logger, verbosity and default selector. An empty name is an error.
func (c *consulDiscoveryClient) Instancer(serviceName string) (discovery.Instancer, error) {
	if len(serviceName) == 0 {
		return nil, fmt.Errorf("empty service name")
	}
	c.mutex.Lock()
	defer c.mutex.Unlock()

	if existing, found := c.instancers[serviceName]; found {
		return existing, nil
	}
	created := NewInstancer(c.ctx, func(opt *InstancerOption) {
		opt.ConsulConnection = c.conn
		opt.Name = serviceName
		opt.Logger = c.config.Logger
		opt.Verbose = c.config.Verbose
		opt.Selector = c.config.DefaultSelector
	})
	c.instancers[serviceName] = created
	return created, nil
}
// Close stops every Instancer created by this client and always returns nil.
// Stopped instancers stay in the cache.
func (c *consulDiscoveryClient) Close() error {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	for name := range c.instancers {
		c.instancers[name].Stop()
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consulsd
import (
"context"
"github.com/cisco-open/go-lanai/pkg/consul"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/discovery/sd"
"github.com/cisco-open/go-lanai/pkg/utils/loop"
"github.com/hashicorp/consul/api"
"sort"
"time"
)
const (
	// defaultIndex is the Consul WaitIndex used when no previous query metadata is available.
	defaultIndex uint64 = 0
)

// InstancerOptions is a functional option applied to InstancerOption in NewInstancer.
type InstancerOptions func(opt *InstancerOption)

// InstancerOption configures a Consul-backed Instancer.
type InstancerOption struct {
	sd.InstancerOption
	// Selector, when non-nil, filters the instances built from Consul's service entries.
	Selector discovery.InstanceMatcher
	// ConsulConnection is the connection whose client is used to query Consul's health API.
	ConsulConnection *consul.Connection
}

// Instancer implements discovery.Instancer
// It yields service for a serviceName in Consul.
// See discovery.Instancer
type Instancer struct {
	sd.CachedInstancer
	consul *consul.Connection
	// lastMeta is the metadata of the previous Consul query; its LastIndex becomes the next query's WaitIndex.
	lastMeta *api.QueryMeta
	selector discovery.InstanceMatcher
}
// NewInstancer creates a Consul-backed discovery.Instancer, wires its background refresh to Consul's
// health-service query, and starts it with ctx.
// Unless opts override them, the package logger is used and the refresher is configured with
// loop.ExponentialRepeatIntervalOnError (50ms, sd.DefaultRefreshBackoffFactor).
// See discovery.Instancer
func NewInstancer(ctx context.Context, opts ...InstancerOptions) *Instancer {
	var opt InstancerOption
	opt.Logger = logger
	opt.RefresherOptions = []loop.TaskOptions{
		loop.ExponentialRepeatIntervalOnError(50*time.Millisecond, sd.DefaultRefreshBackoffFactor),
	}
	for _, apply := range opts {
		apply(&opt)
	}

	instancer := &Instancer{
		CachedInstancer: sd.MakeCachedInstancer(func(baseOpt *sd.CachedInstancerOption) {
			baseOpt.InstancerOption = opt.InstancerOption
		}),
		consul:   opt.ConsulConnection,
		selector: opt.Selector,
	}
	instancer.BackgroundRefreshFunc = instancer.resolveInstancesTask()
	instancer.Start(ctx)
	return instancer
}
// resolveInstancesTask returns the background refresh task. Each run issues a Consul health-service query
// (blocking on the last seen index) and converts the result into a *discovery.Service.
//
// Note:
// Consul doesn't support more than one tag in its serviceName query method.
// https://github.com/hashicorp/consul/issues/294
// Hashi suggests prepared queries, but they don't support blocking.
// https://www.consul.io/docs/agent/http/query.html#execute
// If we want blocking for efficiency, we can use a single tag.
func (i *Instancer) resolveInstancesTask() func(ctx context.Context) (*discovery.Service, error) {
	return func(ctx context.Context) (*discovery.Service, error) {
		// Note: i.lastMeta is only updated in this function, and this function is executed via loop.Loop.
		// Because loop.Loop guarantees that all tasks are executed one-by-one,
		// there is no need for locking.
		lastIndex := defaultIndex
		if i.lastMeta != nil {
			lastIndex = i.lastMeta.LastIndex
		}
		opts := &api.QueryOptions{
			WaitIndex: lastIndex,
		}
		entries, meta, e := i.consul.Client().Health().Service(i.ServiceName(), "", false, opts.WithContext(ctx))
		if meta != nil && meta.LastIndex < lastIndex {
			// Consul's index went backwards, e.g. after a snapshot restore or a Raft store reset.
			// Consul's blocking-query guidance is to restart from index 0. Otherwise the next query
			// would wait on an index the server may not reach for a long time.
			reset := *meta
			reset.LastIndex = defaultIndex
			meta = &reset
		}
		// On error meta is nil, so the next query starts from defaultIndex (non-blocking).
		i.lastMeta = meta
		insts := makeInstances(entries, i.selector)
		service := &discovery.Service{
			Name:  i.ServiceName(),
			Insts: insts,
			Time:  time.Now(),
			Err:   e,
		}
		return service, e
	}
}
/***********************
Helpers
***********************/
// makeInstances converts Consul service entries into discovery instances.
// An empty service address falls back to the node address. When selector is non-nil, an entry is kept only if
// the selector matches it without error. The result is never nil and is stably sorted by instance ID.
func makeInstances(entries []*api.ServiceEntry, selector discovery.InstanceMatcher) []*discovery.Instance {
	result := make([]*discovery.Instance, 0, len(entries))
	for _, entry := range entries {
		svc := entry.Service
		address := svc.Address
		if len(address) == 0 {
			address = entry.Node.Address
		}
		candidate := &discovery.Instance{
			ID:       svc.ID,
			Service:  svc.Service,
			Address:  address,
			Port:     svc.Port,
			Tags:     svc.Tags,
			Meta:     svc.Meta,
			Health:   parseHealth(entry),
			RawEntry: entry,
		}
		if selector != nil {
			if ok, err := selector.Matches(candidate); err != nil || !ok {
				continue
			}
		}
		result = append(result, candidate)
	}
	sort.SliceStable(result, func(a, b int) bool {
		return result[a].ID < result[b].ID
	})
	return result
}
// parseHealth maps the aggregated status of an entry's health checks to a discovery.HealthStatus.
// Unrecognized statuses map to discovery.HealthAny.
func parseHealth(entry *api.ServiceEntry) discovery.HealthStatus {
	var status discovery.HealthStatus
	switch entry.Checks.AggregatedStatus() {
	case api.HealthPassing:
		status = discovery.HealthPassing
	case api.HealthWarning:
		status = discovery.HealthWarning
	case api.HealthCritical:
		status = discovery.HealthCritical
	case api.HealthMaint:
		status = discovery.HealthMaintenance
	default:
		status = discovery.HealthAny
	}
	return status
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consulsd
import (
"context"
"embed"
"fmt"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/consul"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
"io"
)
// logger is the package-wide logger for Consul service discovery.
var logger = log.New("SD.Consul")

// defaultConfigFS embeds defaults-discovery.yml, which Module registers as embedded default properties.
//
//go:embed defaults-discovery.yml
var defaultConfigFS embed.FS

// Module wires Consul-based service registration and discovery into the application. It:
//   - binds DiscoveryProperties and provides the registration customizers, registrar, registration and discovery.Client;
//   - registers this service on application start and deregisters it on stop (registerService);
//   - closes the discovery.Client on stop (closeDiscoveryClient).
var Module = &bootstrap.Module{
	Name:       "consul service discovery",
	Precedence: bootstrap.ServiceDiscoveryPrecedence,
	Options: []fx.Option{
		appconfig.FxEmbeddedDefaults(defaultConfigFS),
		fx.Provide(
			BindDiscoveryProperties,
			fx.Annotate(discovery.NewBuildInfoCustomizer, fxRegistrationCustomizerGroupTag()),
			fx.Annotate(providePropertiesBasedCustomizer, fxRegistrationCustomizerGroupTag()),
			NewServiceRegistrar,
			provideRegistration,
			provideDiscoveryClient),
		fx.Invoke(registerService, closeDiscoveryClient),
	},
}
func Use() {
bootstrap.Register(Module)
}
// fxRegistrationCustomizerGroupTag annotates a provider's result so that it joins the discovery.FxGroup value group,
// which is the group regDI.Customizers is populated from.
func fxRegistrationCustomizerGroupTag() fx.Annotation {
	groupTag := fmt.Sprintf(`group:"%s"`, discovery.FxGroup)
	return fx.ResultTags(groupTag)
}
// providePropertiesBasedCustomizer provides discovery's properties-based registration customizer.
// It passes a nil path map, which makes the customizer use its default property paths.
func providePropertiesBasedCustomizer(appCtx *bootstrap.ApplicationContext) discovery.ServiceRegistrationCustomizer {
	var defaultPaths map[string]string
	return discovery.NewPropertiesBasedCustomizer(appCtx, defaultPaths)
}
// regDI collects the fx-injected dependencies of provideRegistration.
type regDI struct {
	fx.In
	AppCtx *bootstrap.ApplicationContext
	Props  DiscoveryProperties
	// Customizers is populated from the "discovery" value group (see fxRegistrationCustomizerGroupTag).
	Customizers []discovery.ServiceRegistrationCustomizer `group:"discovery"`
}
// provideRegistration builds this service's registration from the application context, the bound
// DiscoveryProperties and every customizer in the discovery value group.
func provideRegistration(di regDI) discovery.ServiceRegistration {
	return NewRegistration(
		di.AppCtx,
		RegistrationWithAppContext(di.AppCtx),
		RegistrationWithProperties(&di.Props),
		RegistrationWithCustomizers(di.Customizers...),
	)
}
// provideDiscoveryClient creates the Consul discovery client. Its default instance selector is derived from
// props.DefaultSelector (and may be nil when no tags or meta are configured).
func provideDiscoveryClient(ctx *bootstrap.ApplicationContext, conn *consul.Connection, props DiscoveryProperties) discovery.Client {
	withDefaultSelector := func(opt *ClientConfig) {
		opt.DefaultSelector = InstanceWithProperties(&props.DefaultSelector)
	}
	return NewDiscoveryClient(ctx, conn, withDefaultSelector)
}
// registerService registers the service with the registrar when the fx application starts and
// deregisters it when the application stops.
func registerService(lc fx.Lifecycle, registrar discovery.ServiceRegistrar, registration discovery.ServiceRegistration) {
	// This module has the lowest precedence, so by the time this hook runs everything else is ready.
	onStart := func(ctx context.Context) error {
		return registrar.Register(ctx, registration)
	}
	onStop := func(ctx context.Context) error {
		return registrar.Deregister(ctx, registration)
	}
	lc.Append(fx.Hook{OnStart: onStart, OnStop: onStop})
}
// closeDiscoveryClient appends an fx stop hook that closes the discovery.Client when the application stops.
// A client that does not implement io.Closer (possible if the provider is replaced or decorated) is left
// untouched, instead of triggering a type-assertion panic during shutdown.
func closeDiscoveryClient(lc fx.Lifecycle, client discovery.Client) {
	lc.Append(fx.StopHook(func(_ context.Context) error {
		closer, ok := client.(io.Closer)
		if !ok {
			return nil
		}
		return closer.Close()
	}))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consulsd
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
"github.com/pkg/errors"
)
const (
	// PropertiesPrefix is the configuration prefix that DiscoveryProperties is bound from.
	PropertiesPrefix = "cloud.discovery.consul"
)

// DiscoveryProperties holds the Consul registration and discovery settings bound from PropertiesPrefix.
//
//goland:noinspection GoNameStartsWithPackageName
type DiscoveryProperties struct {
	// NOTE(review): HealthCheckScheme and HealthCheckPort are not read by RegistrationWithProperties in this
	// package (it uses Scheme and Port instead). Confirm whether they are consumed elsewhere.
	HealthCheckScheme          string                    `json:"health-check-scheme"`
	HealthCheckPath            string                    `json:"health-check-path"` // Path appended to the HTTP health-check URL
	HealthCheckPort            int                       `json:"health-check-port"`
	HealthCheckInterval        string                    `json:"health-check-interval"` // Interval of the registered HTTP check
	Tags                       utils.CommaSeparatedSlice `json:"tags"`                  // Extra tags added to the registration
	AclToken                   string                    `json:"acl-token"`
	IpAddress                  string                    `json:"ip-address"`                    // A pre-defined IP address to register
	Interface                  string                    `json:"interface"`                     // Network interface to take the IP address from; ignored when IpAddress is set
	Port                       int                       `json:"port"`                          // Service port to register
	Scheme                     string                    `json:"scheme"`                        // Scheme of the health-check URL; also drives the "secure" tag
	HealthCheckCriticalTimeout string                    `json:"health-check-critical-timeout"` // See api.AgentServiceCheck's DeregisterCriticalServiceAfter field
	DefaultSelector            SelectorProperties        `json:"default-selector"`              // Default tags or meta to use when discovering other services
}

// SelectorProperties describes the tags and meta an instance must have to be selected (see InstanceWithProperties).
type SelectorProperties struct {
	Tags utils.CommaSeparatedSlice `json:"tags"`
	Meta map[string]string         `json:"meta"`
}
// NewDiscoveryProperties returns DiscoveryProperties with defaults: scheme "http", a 15s health-check interval,
// a 15s critical timeout, health-check path "/admin/health", and port 0.
func NewDiscoveryProperties() *DiscoveryProperties {
	props := &DiscoveryProperties{
		Scheme:                     "http",
		HealthCheckInterval:        "15s",
		HealthCheckCriticalTimeout: "15s",
	}
	props.Port = 0
	props.HealthCheckPath = fmt.Sprintf("%s", "/admin/health")
	return props
}
// BindDiscoveryProperties binds the configuration under PropertiesPrefix on top of the defaults from
// NewDiscoveryProperties. It panics if binding fails.
func BindDiscoveryProperties(ctx *bootstrap.ApplicationContext) DiscoveryProperties {
	props := NewDiscoveryProperties()
	err := ctx.Config().Bind(props, PropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind DiscoveryProperties"))
	}
	return *props
}
// InstanceWithProperties returns an InstanceMatcher that requires every non-empty tag in props.Tags
// (case-insensitive) and every key/value pair in props.Meta. An empty meta value only requires the key to be present.
// It returns nil when props contributes no conditions at all.
func InstanceWithProperties(props *SelectorProperties) discovery.InstanceMatcher {
	conditions := make([]matcher.Matcher, 0, len(props.Tags)+len(props.Meta))
	for _, tag := range props.Tags {
		if tag == "" {
			continue
		}
		conditions = append(conditions, discovery.InstanceWithTag(tag, true))
	}
	for key, value := range props.Meta {
		conditions = append(conditions, discovery.InstanceWithMetaKV(key, value))
	}
	if len(conditions) != 0 {
		first, rest := conditions[0], conditions[1:]
		return matcher.And(first, rest...)
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consulsd
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/consul"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
netutil "github.com/cisco-open/go-lanai/pkg/utils/net"
"github.com/hashicorp/consul/api"
)
// NewServiceRegistrar returns a discovery.ServiceRegistrar that (de)registers services with the Consul agent
// reachable through conn.
func NewServiceRegistrar(conn *consul.Connection) discovery.ServiceRegistrar {
	registrar := consulServiceRegistrar{conn: conn}
	return registrar
}
// consulServiceRegistrar implements discovery.ServiceRegistrar through the agent API of a Consul connection.
type consulServiceRegistrar struct {
	conn *consul.Connection
}
func (r consulServiceRegistrar) Register(ctx context.Context, registration discovery.ServiceRegistration) error {
reg, ok := registration.(*ServiceRegistration)
if !ok {
return fmt.Errorf(`unsupported registration type [%T]`, registration)
}
if e := r.conn.Client().Agent().ServiceRegister(®.AgentServiceRegistration); e != nil {
return e
}
logger.WithContext(ctx).WithKV(r.registrationKVs(registration)...).Infof("Register")
return nil
}
// Deregister removes the registration, identified by its ID, from the local Consul agent and logs it on success.
func (r consulServiceRegistrar) Deregister(ctx context.Context, registration discovery.ServiceRegistration) error {
	agent := r.conn.Client().Agent()
	err := agent.ServiceDeregister(registration.ID())
	if err != nil {
		return err
	}
	logger.WithContext(ctx).WithKV(r.registrationKVs(registration)...).Infof("Deregister")
	return nil
}
// registrationKVs flattens the registration into alternating key/value pairs for structured logging.
func (r consulServiceRegistrar) registrationKVs(reg discovery.ServiceRegistration) []interface{} {
	kvs := make([]interface{}, 0, 10)
	kvs = append(kvs, "service", reg.Name())
	kvs = append(kvs, "address", reg.Address())
	kvs = append(kvs, "port", reg.Port())
	kvs = append(kvs, "tags", reg.Tags())
	kvs = append(kvs, "meta", reg.Meta())
	return kvs
}
// ServiceRegistration implements discovery.ServiceRegistration
// by wrapping the api.AgentServiceRegistration that consulServiceRegistrar submits to the Consul agent.
type ServiceRegistration struct {
	api.AgentServiceRegistration
}
// ID returns the Consul service ID.
func (r *ServiceRegistration) ID() string { return r.AgentServiceRegistration.ID }

// Name returns the Consul service name.
func (r *ServiceRegistration) Name() string { return r.AgentServiceRegistration.Name }

// Address returns the registered service address.
func (r *ServiceRegistration) Address() string { return r.AgentServiceRegistration.Address }

// Port returns the registered service port.
func (r *ServiceRegistration) Port() int { return r.AgentServiceRegistration.Port }

// Tags returns the registration's tag slice (not a copy).
func (r *ServiceRegistration) Tags() []string { return r.AgentServiceRegistration.Tags }

// Meta returns a fresh copy of the registration's metadata with values widened to `any`.
// The result is never nil.
func (r *ServiceRegistration) Meta() map[string]any {
	copied := make(map[string]any)
	for key, value := range r.AgentServiceRegistration.Meta {
		copied[key] = value
	}
	return copied
}

// SetID sets the Consul service ID.
func (r *ServiceRegistration) SetID(id string) { r.AgentServiceRegistration.ID = id }

// SetName sets the Consul service name.
func (r *ServiceRegistration) SetName(name string) { r.AgentServiceRegistration.Name = name }

// SetAddress sets the registered service address.
func (r *ServiceRegistration) SetAddress(addr string) { r.AgentServiceRegistration.Address = addr }

// SetPort sets the registered service port.
func (r *ServiceRegistration) SetPort(port int) { r.AgentServiceRegistration.Port = port }
// AddTags appends each tag that is not already present (or repeated earlier in tags),
// keeping the order in which tags are given.
func (r *ServiceRegistration) AddTags(tags ...string) {
	seen := utils.NewStringSet(r.AgentServiceRegistration.Tags...)
	for _, tag := range tags {
		if !seen.Has(tag) {
			seen.Add(tag)
			r.AgentServiceRegistration.Tags = append(r.AgentServiceRegistration.Tags, tag)
		}
	}
}
// RemoveTags drops every tag that equals (exactly) one of the given tags and keeps the order of the rest.
// The filtering happens in place, reusing the existing backing array.
func (r *ServiceRegistration) RemoveTags(tags ...string) {
	current := r.AgentServiceRegistration.Tags
	kept := current[:0]
next:
	for _, existing := range current {
		for _, unwanted := range tags {
			if existing == unwanted {
				continue next
			}
		}
		kept = append(kept, existing)
	}
	r.AgentServiceRegistration.Tags = kept
}
// SetMeta stores value under key in the registration metadata, formatted with %v.
// A nil value removes the key instead. The metadata map is created on first use.
func (r *ServiceRegistration) SetMeta(key string, value any) {
	if r.AgentServiceRegistration.Meta == nil {
		r.AgentServiceRegistration.Meta = make(map[string]string)
	}
	if value != nil {
		r.AgentServiceRegistration.Meta[key] = fmt.Sprintf(`%v`, value)
		return
	}
	delete(r.AgentServiceRegistration.Meta, key)
}
// RegistrationOptions is a functional option applied to RegistrationConfig in NewRegistration.
type RegistrationOptions func(cfg *RegistrationConfig)

// RegistrationConfig collects the inputs NewRegistration uses to build a ServiceRegistration.
type RegistrationConfig struct {
	ApplicationName string
	// IPAddress is registered as-is; when empty, the address of NetworkInterface is looked up instead.
	IPAddress        string
	NetworkInterface string
	Port             int
	Tags             []string
	HealthCheckPath  string
	// HealthPort defaults to Port when zero.
	HealthPort   int
	HealthScheme string
	// HealthCheckInterval and HealthCheckCriticalTimeout are passed to Consul's AgentServiceCheck
	// (Interval and DeregisterCriticalServiceAfter respectively).
	HealthCheckInterval        string
	HealthCheckCriticalTimeout string
	// Customizers run, in order, on the built registration.
	Customizers []discovery.ServiceRegistrationCustomizer
}
func NewRegistration(ctx context.Context, opts ...RegistrationOptions) discovery.ServiceRegistration {
cfg := RegistrationConfig{}
for _, fn := range opts {
fn(&cfg)
}
if len(cfg.IPAddress) == 0 {
cfg.IPAddress, _ = netutil.GetIp(cfg.NetworkInterface)
}
if cfg.HealthPort == 0 {
cfg.HealthPort = cfg.Port
}
reg := ServiceRegistration{
AgentServiceRegistration: api.AgentServiceRegistration{
Kind: api.ServiceKindTypical,
ID: fmt.Sprintf("%s-%d-%x", cfg.ApplicationName, cfg.Port, cryptoutils.RandomBytes(5)),
Name: cfg.ApplicationName,
Tags: cfg.Tags,
Port: cfg.Port,
Address: cfg.IPAddress,
Check: &api.AgentServiceCheck{
HTTP: fmt.Sprintf("%s://%s:%d%s", cfg.HealthScheme, cfg.IPAddress, cfg.HealthPort, cfg.HealthCheckPath),
Interval: cfg.HealthCheckInterval,
DeregisterCriticalServiceAfter: cfg.HealthCheckCriticalTimeout},
},
}
for _, c := range cfg.Customizers {
c.Customize(ctx, ®)
}
return ®
}
// RegistrationWithProperties copies the network, port, health-check and tag settings from props into the config.
// A "secure=<bool>" tag (true only when props.Scheme is "https") is appended before props.Tags.
//
// NOTE(review): props.HealthCheckScheme and props.HealthCheckPort are not consumed here. The check scheme comes
// from props.Scheme, and the health port is left for NewRegistration to default to Port. Confirm this is intended.
func RegistrationWithProperties(props *DiscoveryProperties) RegistrationOptions {
	return func(cfg *RegistrationConfig) {
		cfg.IPAddress, cfg.NetworkInterface, cfg.Port = props.IpAddress, props.Interface, props.Port
		cfg.HealthCheckPath = props.HealthCheckPath
		cfg.HealthScheme = props.Scheme
		cfg.HealthCheckInterval = props.HealthCheckInterval
		cfg.HealthCheckCriticalTimeout = props.HealthCheckCriticalTimeout

		secure := props.Scheme == "https"
		cfg.Tags = append(cfg.Tags, fmt.Sprintf("secure=%t", secure))
		cfg.Tags = append(cfg.Tags, props.Tags...)
	}
}
// RegistrationWithAppContext sets the application name from appCtx. The name is read when the option is applied.
func RegistrationWithAppContext(appCtx *bootstrap.ApplicationContext) RegistrationOptions {
	setName := func(cfg *RegistrationConfig) {
		cfg.ApplicationName = appCtx.Name()
	}
	return setName
}
// RegistrationWithCustomizers appends customizers to those already configured; they run in order.
func RegistrationWithCustomizers(customizers ...discovery.ServiceRegistrationCustomizer) RegistrationOptions {
	addCustomizers := func(cfg *RegistrationConfig) {
		cfg.Customizers = append(cfg.Customizers, customizers...)
	}
	return addCustomizers
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package discovery
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
"strings"
"time"
)
const (
	// FxGroup is the fx value-group name for ServiceRegistrationCustomizer providers.
	FxGroup = "discovery"
)

// Keys used in registration tags and metadata (see the registration customizers).
const (
	TagInstanceUUID        = `instanceUuid`
	TagContextPath         = `contextPath`
	TagComponentAttributes = `componentAttributes`
	TagServiceName         = `name`
	TagBuildVersion        = `version`
	TagBuildNumber         = `buildNumber`
	TagBuildDateTime       = `buildDateTime`
	TagSecure              = `secure`

	// ComponentAttributeDelimiter separates "key:value" entries inside the TagComponentAttributes tag value.
	ComponentAttributeDelimiter = `~`
	// ComponentAttributeKeyValueSeparator separates key and value of a single component attribute.
	ComponentAttributeKeyValueSeparator = `:`
)

// Metadata keys looked up on discovered instances; InstanceMetaKeyVersion is used by InstanceWithVersion.
const (
	InstanceMetaKeyVersion     = `version`
	InstanceMetaKeyContextPath = `context`
	InstanceMetaKeySMCR        = `SMCR`
	//InstanceMetaKey = ``
)

// Health statuses of an instance. HealthAny acts as a wildcard in InstanceWithHealth.
const (
	HealthAny HealthStatus = iota
	HealthPassing
	HealthWarning
	HealthCritical
	HealthMaintenance
)

var (
	// ErrInstancerStopped is presumably returned by Instancer implementations once stopped — confirm in sd package.
	ErrInstancerStopped = fmt.Errorf("instancer is already stopped")
)
// Client gives access to per-service Instancers.
type Client interface {
	// Context returns the context the client was created with.
	Context() context.Context
	// Instancer returns the Instancer for the given service name.
	Instancer(serviceName string) (Instancer, error)
}

// HealthStatus maintenance > critical > warning > passing
type HealthStatus int

// Service is a snapshot of a service's instances as resolved by an Instancer.
type Service struct {
	Name  string
	Insts []*Instance
	// Time is when this snapshot was produced.
	Time time.Time
	// Err is the error, if any, of the resolution that produced this snapshot.
	Err error
	// FirstErrAt presumably records when the current run of errors began — TODO confirm against the sd package.
	FirstErrAt time.Time
}
// Instances returns the instances accepted by selector, in their stored order. A nil selector accepts all of them.
// Instances for which the selector returns an error are skipped. The result is nil when nothing matches.
func (s *Service) Instances(selector InstanceMatcher) (ret []*Instance) {
	for _, candidate := range s.Insts {
		keep := true
		if selector != nil {
			matched, err := selector.Matches(candidate)
			keep = err == nil && matched
		}
		if keep {
			ret = append(ret, candidate)
		}
	}
	return ret
}
// InstanceCount counts the instances accepted by selector, with the same rules as Instances.
func (s *Service) InstanceCount(selector InstanceMatcher) (ret int) {
	for _, candidate := range s.Insts {
		if selector == nil {
			ret++
			continue
		}
		if matched, err := selector.Matches(candidate); err == nil && matched {
			ret++
		}
	}
	return ret
}
// Instance describes a single discovered instance of a service.
type Instance struct {
	ID      string
	Service string
	Address string
	Port    int
	Tags    []string
	Meta    map[string]string
	Health  HealthStatus
	// RawEntry holds the provider-specific source record (e.g. *api.ServiceEntry for Consul, net.SRV for DNS).
	RawEntry interface{}
}

// Callback receives the Instancer it was registered on (see Instancer.RegisterCallback).
type Callback func(Instancer)

// InstanceMatcher is a matcher.Matcher that takes Instance or *Instance
type InstanceMatcher matcher.ChainableMatcher

// Instancer tracks the instances of a single named service.
type Instancer interface {
	ServiceName() string
	Service() *Service
	Instances(InstanceMatcher) ([]*Instance, error)
	Start(ctx context.Context)
	Stop()
	RegisterCallback(id interface{}, cb Callback)
	DeregisterCallback(id interface{})
}

// ServiceCache is not goroutine-safe unless the detail implementation says so
type ServiceCache interface {
	// Get returns service with given service name. return nil if not exist
	Get(name string) *Service
	// Set stores given service with name, returns non-nil if the service is already exists
	Set(name string, svc *Service) *Service
	// SetWithTTL stores given service with name and TTL, returns non-nil if the service is already exists
	// if ttl is zero or negative value, it's equivalent to Set
	SetWithTTL(name string, svc *Service, ttl time.Duration) *Service
	// Has reports whether a service is stored under name.
	Has(name string) bool
	// Entries returns the stored services keyed by name.
	Entries() map[string]*Service
}
/*************************
Common Impl
*************************/
var (
	// healthyInstanceMatcher matches instances whose Health is exactly HealthPassing.
	// It is shared because it carries no per-call state.
	healthyInstanceMatcher = &instanceMatcher{
		desc: "is healthy",
		matchFunc: func(_ context.Context, instance *Instance) (bool, error) {
			return instance.Health == HealthPassing, nil
		},
	}
)
// InstanceIsHealthy returns the shared InstanceMatcher that accepts only HealthPassing instances.
func InstanceIsHealthy() InstanceMatcher {
	var m InstanceMatcher = healthyInstanceMatcher
	return m
}
// InstanceWithVersion matches instances whose InstanceMetaKeyVersion metadata equals verPattern exactly.
// Despite the parameter name, no pattern matching is performed.
func InstanceWithVersion(verPattern string) InstanceMatcher {
	match := func(_ context.Context, instance *Instance) (bool, error) {
		// Indexing a nil map is safe and reports "not found".
		ver, found := instance.Meta[InstanceMetaKeyVersion]
		return found && ver == verPattern, nil
	}
	return &instanceMatcher{
		desc:      fmt.Sprintf("of version %s", verPattern),
		matchFunc: match,
	}
}
// InstanceWithHealth matches instances with the given health status. HealthAny matches every instance.
func InstanceWithHealth(status HealthStatus) InstanceMatcher {
	match := func(_ context.Context, instance *Instance) (bool, error) {
		if status == HealthAny {
			return true, nil
		}
		return instance.Health == status, nil
	}
	return &instanceMatcher{
		desc:      fmt.Sprintf("with health status %d", status),
		matchFunc: match,
	}
}
// InstanceWithMetaKV matches instances whose metadata contains key. When value is non-empty, the stored value
// must also equal it.
func InstanceWithMetaKV(key, value string) InstanceMatcher {
	match := func(_ context.Context, instance *Instance) (bool, error) {
		actual, present := instance.Meta[key]
		if !present {
			return false, nil
		}
		return value == "" || actual == value, nil
	}
	return &instanceMatcher{
		desc:      fmt.Sprintf("has meta %s=%s", key, value),
		matchFunc: match,
	}
}
// InstanceWithTag matches instances that carry tag, compared exactly or, when caseInsensitive is set,
// with Unicode case folding.
func InstanceWithTag(tag string, caseInsensitive bool) InstanceMatcher {
	match := func(_ context.Context, instance *Instance) (bool, error) {
		for _, candidate := range instance.Tags {
			if candidate == tag {
				return true, nil
			}
			if caseInsensitive && strings.EqualFold(candidate, tag) {
				return true, nil
			}
		}
		return false, nil
	}
	return &instanceMatcher{
		desc:      fmt.Sprintf("with tag %s", tag),
		matchFunc: match,
	}
}
// InstanceWithTagKV matches instances carrying a "key=value" tag. Each tag is trimmed and split at its first "=".
// With caseInsensitive, the tags, key and value are all lower-cased before comparing.
func InstanceWithTagKV(key, value string, caseInsensitive bool) InstanceMatcher {
	normalize := func(s string) string {
		if caseInsensitive {
			return strings.ToLower(s)
		}
		return s
	}
	key, value = normalize(key), normalize(value)
	return &instanceMatcher{
		desc: fmt.Sprintf("with tag %s=%s", key, value),
		matchFunc: func(_ context.Context, instance *Instance) (bool, error) {
			for _, raw := range instance.Tags {
				k, v, found := strings.Cut(strings.TrimSpace(normalize(raw)), "=")
				if found && k == key && v == value {
					return true, nil
				}
			}
			return false, nil
		},
	}
}
// instanceMatcher implements InstanceMatcher and accept *Instance and Instance
type instanceMatcher struct {
	// matchFunc performs the actual check on the normalized *Instance.
	matchFunc func(context.Context, *Instance) (bool, error)
	// desc is returned by String.
	desc string
}
// Matches implements matcher.Matcher by delegating to MatchesWithContext with context.TODO().
func (m *instanceMatcher) Matches(i interface{}) (bool, error) {
	return m.MatchesWithContext(context.TODO(), i)
}

// MatchesWithContext accepts *Instance or Instance (the latter is passed by address of a copy).
// Any other input type yields an error.
func (m *instanceMatcher) MatchesWithContext(c context.Context, i interface{}) (ret bool, err error) {
	if ptr, ok := i.(*Instance); ok {
		return m.matchFunc(c, ptr)
	}
	if val, ok := i.(Instance); ok {
		return m.matchFunc(c, &val)
	}
	return false, fmt.Errorf("expect *Instance but got %T", i)
}

// Or combines this matcher with others using matcher.Or.
func (m *instanceMatcher) Or(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.Or(m, matchers...)
}

// And combines this matcher with others using matcher.And.
func (m *instanceMatcher) And(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.And(m, matchers...)
}

// String returns the human-readable description of the matcher.
func (m *instanceMatcher) String() string {
	return m.desc
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package discovery
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/google/uuid"
"strings"
)
// NewBuildInfoCustomizer returns a ServiceRegistrationCustomizer that copies the service's build information
// (bootstrap.BuildVersion and bootstrap.BuildTime) into both metadata and "key=value" tags.
// When the version has exactly one "-", the part after it is also published as TagBuildNumber.
func NewBuildInfoCustomizer() ServiceRegistrationCustomizer {
	return ServiceRegistrationCustomizerFunc(func(_ context.Context, reg ServiceRegistration) {
		attrs := make(map[string]string, 3)
		attrs[TagBuildVersion] = bootstrap.BuildVersion
		attrs[TagBuildDateTime] = bootstrap.BuildTime
		if parts := strings.Split(bootstrap.BuildVersion, "-"); len(parts) == 2 {
			attrs[TagBuildNumber] = parts[1]
		}
		for name, val := range attrs {
			reg.SetMeta(name, val)
			reg.AddTags(fmt.Sprintf("%s=%s", name, val))
		}
	})
}
// defaultPropertyPaths maps metadata keys to the property paths NewPropertiesBasedCustomizer reads
// when it is given a nil map.
var defaultPropertyPaths = map[string]string{
	`serviceName`: `application.name`,
	`context`:     `server.context-path`,
	`name`:        `info.app.attributes.displayName`,
	`description`: `info.app.description`,
	`parent`:      `info.app.attributes.parent`,
	`type`:        `info.app.attributes.type`,
}
// NewPropertiesBasedCustomizer returns a ServiceRegistrationCustomizer that populates tags and metadata from the
// service's loaded properties. propertyPaths maps a metadata key to the property path its value is read from;
// nil selects defaultPropertyPaths.
//
// Each invocation:
//   - sets a fresh instance UUID, the application name and the context path as metadata and as "key=value" tags;
//   - sets every resolvable (non-nil) property as metadata, overriding a static key with the same name;
//   - joins those properties as "key:value" entries into a single TagComponentAttributes tag.
func NewPropertiesBasedCustomizer(appCtx *bootstrap.ApplicationContext, propertyPaths map[string]string) ServiceRegistrationCustomizer {
	paths := propertyPaths
	if paths == nil {
		paths = defaultPropertyPaths
	}
	return ServiceRegistrationCustomizerFunc(func(_ context.Context, reg ServiceRegistration) {
		instanceID := uuid.New()
		appName := appCtx.Name()
		ctxPath, _ := appCtx.Value(`server.context-path`).(string)

		// Static metadata goes first so that property-based keys (e.g. "name") can override it.
		reg.SetMeta(TagInstanceUUID, instanceID)
		reg.SetMeta(TagServiceName, appName)
		reg.SetMeta(TagContextPath, ctxPath)

		attrs := make([]string, 0, len(paths))
		for key, path := range paths {
			if value := appCtx.Value(path); value != nil {
				reg.SetMeta(key, value)
				attrs = append(attrs, fmt.Sprintf("%s%s%v", key, ComponentAttributeKeyValueSeparator, value))
			}
		}

		reg.AddTags(
			kvTag(TagInstanceUUID, instanceID.String()),
			kvTag(TagServiceName, appName),
			kvTag(TagContextPath, ctxPath),
			kvTag(TagComponentAttributes, strings.Join(attrs, ComponentAttributeDelimiter)),
		)
	})
}
// kvTag formats a "key=value" tag. fmt.Sprint inserts no separators here because all operands are strings.
func kvTag(k string, v string) string {
	return fmt.Sprint(k, "=", v)
}
package dnssd
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils/loop"
"regexp"
"sync"
"time"
)
// ClientOptions is a functional option applied to ClientConfig when a DNS discovery client is created.
type ClientOptions func(opt *ClientConfig)

// ClientConfig holds the settings of the DNS discovery client; they are handed to every Instancer it creates.
type ClientConfig struct {
	Logger  log.ContextualLogger
	Verbose bool
	// DNSServerAddr is the address and port of DNS server. e.g. "8.8.8.8:53"
	DNSServerAddr string
	// FQDNTemplate see DiscoveryProperties.FQDNTemplate
	FQDNTemplate string
	// SRVProto see DiscoveryProperties.SRVProto
	SRVProto string
	// SRVService see DiscoveryProperties.SRVService
	SRVService string
	// RefreshInterval interval for background refresher.
	// Note: Foreground refresh happens everytime when Instancer.Service or Instancer.Instances is invoked.
	// Background refresh is for callbacks only
	// Default: 30s
	RefreshInterval time.Duration
	// FallbackHostMappings known host mappings including default template. Used when DNS lookup fails.
	// See DiscoveryProperties.Fallback.
	// Note: Default mapping should be at the end of the list with `.*` as service name pattern
	FallbackHostMappings []HostMapping
}

// HostMapping associates a service-name pattern with fallback host templates.
type HostMapping struct {
	// ServiceRegex the compiled regular expression HostMappingProperties.Service
	ServiceRegex *regexp.Regexp
	// Hosts is a list of known hosts. Each entry should be a golang template with single-line output.
	// See HostMappingProperties.Hosts
	Hosts []string
}

// dnsDiscoveryClient implements discovery.Client using DNS SRV lookups.
// It lazily creates one Instancer per service name and caches it.
type dnsDiscoveryClient struct {
	ctx context.Context
	// instancers caches Instancers by service name; access is guarded by mutex.
	instancers map[string]*Instancer
	mutex      sync.Mutex
	config     ClientConfig
}
// NewDiscoveryClient creates a DNS-backed discovery.Client.
// Defaults are the package logger, Verbose disabled and a 30s background refresh interval (defaultRefreshInterval);
// opts are applied in order to override them.
func NewDiscoveryClient(ctx context.Context, opts ...ClientOptions) discovery.Client {
	cfg := ClientConfig{
		Logger:          logger,
		Verbose:         false,
		RefreshInterval: defaultRefreshInterval,
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	return &dnsDiscoveryClient{
		ctx:        ctx,
		instancers: make(map[string]*Instancer),
		config:     cfg,
	}
}
// Context returns the context this client was created with. Instancers created by the client use it.
func (c *dnsDiscoveryClient) Context() context.Context {
	ctx := c.ctx
	return ctx
}
// Instancer returns the cached Instancer for serviceName and creates one on first use.
// Creation failures are returned and not cached, so a later call retries.
// An empty serviceName is rejected with an error. The cache is guarded by c.mutex.
func (c *dnsDiscoveryClient) Instancer(serviceName string) (discovery.Instancer, error) {
	if serviceName == "" {
		return nil, fmt.Errorf("empty service name")
	}
	c.mutex.Lock()
	defer c.mutex.Unlock()
	if cached, ok := c.instancers[serviceName]; ok {
		return cached, nil
	}
	instancer, e := NewInstancer(c.ctx, func(opt *InstancerOption) {
		opt.Name = serviceName
		opt.Logger = c.config.Logger
		opt.Verbose = c.config.Verbose
		opt.DNSServerAddr = c.config.DNSServerAddr
		opt.FQDNTemplate = c.config.FQDNTemplate
		opt.SRVProto = c.config.SRVProto
		opt.SRVService = c.config.SRVService
		opt.HostTemplates = c.findFallbackHostTemplates(serviceName)
		opt.RefresherOptions = []loop.TaskOptions{loop.FixedRepeatInterval(c.config.RefreshInterval)}
	})
	if e != nil {
		// Return an untyped nil. A nil *Instancer wrapped in discovery.Instancer would compare != nil.
		return nil, e
	}
	c.instancers[serviceName] = instancer
	return instancer, nil
}
// Close stops every Instancer this client has created and always returns nil.
// The cache itself is not cleared.
func (c *dnsDiscoveryClient) Close() error {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	for name := range c.instancers {
		c.instancers[name].Stop()
	}
	return nil
}
// findFallbackHostTemplates returns the Hosts of the first FallbackHostMappings entry whose ServiceRegex matches
// svcName, or nil when none matches. Order matters, so a catch-all `.*` mapping is expected to come last.
func (c *dnsDiscoveryClient) findFallbackHostTemplates(svcName string) []string {
	for _, mapping := range c.config.FallbackHostMappings {
		// MatchString avoids converting svcName to a []byte on every check.
		if mapping.ServiceRegex.MatchString(svcName) {
			return mapping.Hosts
		}
	}
	return nil
}
package dnssd
import (
"bytes"
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/discovery/sd"
"github.com/cisco-open/go-lanai/pkg/utils/loop"
"net"
"regexp"
"sort"
"strconv"
"strings"
"text/template"
"time"
)
const (
	// Metadata keys set on instances resolved from SRV records.
	kMetaSRVName    = "_srv_name"
	kMetaSRVService = "_srv_service"
	kMetaSRVProto   = "_srv_proto"
	// NOTE(review): the following are not used in the visible part of this file; presumably they are used when
	// building fallback instances — confirm.
	kMetaScheme  = "scheme"
	kTagSecure   = "secure"
	kTagInsecure = "insecure"
)

var (
	// defaultRefreshInterval is the default background refresh interval of an Instancer.
	defaultRefreshInterval = 30 * time.Second
	// defaultLookupTimeout bounds each SRV lookup.
	defaultLookupTimeout = 2 * time.Second
	// hostPatternRegexp splits an optional "scheme://" prefix (group "scheme") from the rest (group "host").
	hostPatternRegexp = regexp.MustCompile(`((?P<scheme>\w+)://)?(?P<host>.+)`)
)

// InstancerOptions is a functional option applied to InstancerOption in NewInstancer.
type InstancerOptions func(opt *InstancerOption)

// InstancerOption configures a DNS-backed Instancer.
type InstancerOption struct {
	sd.InstancerOption
	// DNSServerAddr, when set, overrides the DNS server the resolver dials ("host:port").
	DNSServerAddr string
	FQDNTemplate  string
	SRVProto      string
	SRVService    string
	// HostTemplates are the fallback host templates used when SRV lookup fails or returns nothing.
	HostTemplates []string
}

// Instancer is a DNS SRV based discovery.Instancer with optional static fallback instances.
type Instancer struct {
	sd.CachedInstancer
	// context is used for the foreground refresh performed by Service and Instances.
	context  context.Context
	resolver *net.Resolver
	fqdn     string
	// srvProto and srvService are stored without leading underscores.
	srvProto   string
	srvService string
	fallback   []*discovery.Instance
}
// NewInstancer creates a DNS SRV based Instancer and starts it with ctx.
// Both foreground and background refreshes resolve the SRV record and switch to the fallback instances
// when the lookup fails or returns nothing.
// Returned errors wrap the underlying cause with %w, so callers can inspect it via errors.Is/errors.As.
func NewInstancer(ctx context.Context, opts ...InstancerOptions) (*Instancer, error) {
	opt := InstancerOption{
		InstancerOption: sd.InstancerOption{
			Logger:           logger,
			RefresherOptions: []loop.TaskOptions{loop.FixedRepeatInterval(defaultRefreshInterval)},
		},
	}
	for _, f := range opts {
		f(&opt)
	}

	// Only override the resolver's dialer when an explicit DNS server is configured;
	// otherwise net.Resolver keeps its default behavior.
	var dial func(ctx context.Context, network, address string) (net.Conn, error)
	if len(opt.DNSServerAddr) != 0 {
		dial = dialWithAddrOverride(opt.DNSServerAddr)
	}

	fqdn, e := fqdnWithTemplate(&opt)
	if e != nil {
		return nil, fmt.Errorf(`failed to execute FQDN template "%s": %w`, opt.FQDNTemplate, e)
	}
	fallback, e := staticInstancesWithTemplates(&opt)
	if e != nil {
		return nil, fmt.Errorf(`failed to process fallback: %w`, e)
	}

	i := &Instancer{
		CachedInstancer: sd.MakeCachedInstancer(func(baseOpt *sd.CachedInstancerOption) {
			baseOpt.InstancerOption = opt.InstancerOption
		}),
		context: ctx,
		resolver: &net.Resolver{
			// The custom Dial is only honored by the pure-Go resolver.
			PreferGo: dial != nil,
			Dial:     dial,
		},
		fqdn:       fqdn,
		srvProto:   strings.TrimLeft(strings.TrimSpace(opt.SRVProto), "_"),
		srvService: strings.TrimLeft(strings.TrimSpace(opt.SRVService), "_"),
		fallback:   fallback,
	}
	i.BackgroundRefreshFunc = i.resolveInstancesTask()
	i.ForegroundRefreshFunc = i.resolveInstancesTask()
	i.Start(ctx)
	return i, nil
}
func (i *Instancer) Service() (svc *discovery.Service) {
_, _ = i.RefreshNow(i.context)
return i.CachedInstancer.Service()
}
// Instances performs a best-effort synchronous refresh, then returns the cached instances accepted by matcher.
// The refresh error is intentionally dropped: it is recorded in the cached Service and surfaced by
// CachedInstancer.Instances.
func (i *Instancer) Instances(matcher discovery.InstanceMatcher) (ret []*discovery.Instance, err error) {
	ctx := i.context
	_, _ = i.RefreshNow(ctx)
	ret, err = i.CachedInstancer.Instances(matcher)
	return
}
// resolveInstancesTask returns the RefreshFunc used for both background and foreground refreshes.
// It looks up SRV records; when the lookup fails or yields no instance and fallback instances exist,
// the fallback instances are used instead and the error is cleared.
// The returned Service is named after the instancer and timestamped with the refresh time.
func (i *Instancer) resolveInstancesTask() func(ctx context.Context) (*discovery.Service, error) {
	return func(ctx context.Context) (*discovery.Service, error) {
		instances, err := i.trySRVRecord(ctx)
		lookupUnusable := err != nil || len(instances) == 0
		if len(i.fallback) != 0 && lookupUnusable {
			instances, err = i.makeInstancesFromFallback(), nil
		}
		return &discovery.Service{
			Name:  i.Name,
			Insts: instances,
			Time:  time.Now(),
			Err:   err,
		}, err
	}
}
// trySRVRecord performs a single SRV lookup bounded by defaultLookupTimeout.
// A "not found" DNS answer is translated into an empty result without error (see translateLookupError).
func (i *Instancer) trySRVRecord(ctx context.Context) ([]*discovery.Instance, error) {
	lookupCtx, cancel := context.WithTimeout(ctx, defaultLookupTimeout)
	defer cancel()
	name, records, err := i.resolver.LookupSRV(lookupCtx, i.srvService, i.srvProto, i.fqdn)
	if err = i.translateLookupError(lookupCtx, err); err != nil {
		return nil, err
	}
	return i.makeInstancesFromSRVs(name, records), nil
}
// makeInstancesFromSRVs converts SRV records into healthy discovery instances sorted by ID ("target:port").
// The SRV labels and the canonical name returned by the lookup are kept in Meta; the raw record is kept
// (by value) in RawEntry. The result is never nil (empty slice for no records).
func (i *Instancer) makeInstancesFromSRVs(name string, srvs []*net.SRV) []*discovery.Instance {
	instances := make([]*discovery.Instance, 0, len(srvs))
	for _, srv := range srvs {
		port := int(srv.Port)
		instances = append(instances, &discovery.Instance{
			ID:      net.JoinHostPort(srv.Target, strconv.Itoa(port)),
			Service: i.Name,
			Address: srv.Target,
			Port:    port,
			Meta: map[string]string{
				kMetaSRVService: i.srvService,
				kMetaSRVProto:   i.srvProto,
				kMetaSRVName:    name,
			},
			Health:   discovery.HealthPassing,
			RawEntry: *srv,
		})
	}
	byID := func(a, b int) bool { return instances[a].ID < instances[b].ID }
	sort.SliceStable(instances, byID)
	return instances
}
// makeInstancesFromFallback returns a shallow copy of the static fallback instances.
// Only the slice is copied, so callers may reorder or truncate it without affecting i.fallback.
// The *discovery.Instance values themselves are shared.
func (i *Instancer) makeInstancesFromFallback() []*discovery.Instance {
	instances := make([]*discovery.Instance, len(i.fallback))
	copy(instances, i.fallback)
	return instances
}
// translateLookupError normalizes an SRV lookup error.
// A *net.DNSError reporting "not found" is not treated as a failure (nil is returned): it simply means there
// are no instances. Any other error is logged (when Verbose) and returned unchanged.
func (i *Instancer) translateLookupError(ctx context.Context, err error) error {
	if err == nil {
		return nil
	}
	var dnsErr *net.DNSError
	if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
		return nil
	}
	i.logError(ctx, err)
	return err
}
// logError emits a debug log describing a failed SRV lookup; it is silent unless Verbose is enabled.
func (i *Instancer) logError(ctx context.Context, err error) {
	if !i.Verbose {
		return
	}
	i.Logger.WithContext(ctx).Debugf(`failed to lookup %s %s %s IN SRV: %v`, i.srvService, i.srvProto, i.fqdn, err)
}
/*******************
Helpers
*******************/
// dialWithAddrOverride returns a net.Resolver Dial hook that ignores the server address chosen by the
// resolver and always connects to addr, keeping the network requested by the resolver.
// A fresh zero-value net.Dialer is used for each connection.
func dialWithAddrOverride(addr string) func(ctx context.Context, network, address string) (net.Conn, error) {
	dial := func(ctx context.Context, network, _ string) (net.Conn, error) {
		return (&net.Dialer{}).DialContext(ctx, network, addr)
	}
	return dial
}
// tmplData is the data passed to FQDN and host templates (see execTemplate).
type tmplData struct {
// ServiceName is the instancer's service name (InstancerOption.Name), exposed as ".ServiceName".
ServiceName string
}
// execTemplate parses tmplText and renders it with tmplData{ServiceName: opt.Name}.
// Parse and execution errors are returned as-is, with an empty string.
func execTemplate(tmplText string, opt *InstancerOption) (string, error) {
	tmpl, err := template.New("single-line").Parse(tmplText)
	if err != nil {
		return "", err
	}
	var rendered bytes.Buffer
	if err = tmpl.Execute(&rendered, tmplData{ServiceName: opt.Name}); err != nil {
		return "", err
	}
	return rendered.String(), nil
}
// fqdnWithTemplate renders opt.FQDNTemplate and returns its host part, discarding any ":<port>" suffix.
func fqdnWithTemplate(opt *InstancerOption) (string, error) {
	rendered, err := execTemplate(opt.FQDNTemplate, opt)
	if err != nil {
		return "", err
	}
	host, _, err := splitAddrAndPort(rendered)
	return host, err
}
// staticInstancesWithTemplates renders every opt.HostTemplates entry into a healthy static instance,
// sorted by ID (the rendered host string).
//
// When a scheme is present it is recorded in Meta[kMetaScheme] and tagged. "http" gets
// insecure=true/secure=false; any other scheme gets insecure=false/secure=true.
// The first template or parse failure aborts with an error.
func staticInstancesWithTemplates(opt *InstancerOption) ([]*discovery.Instance, error) {
	instances := make([]*discovery.Instance, 0, len(opt.HostTemplates))
	for _, text := range opt.HostTemplates {
		host, err := execTemplate(text, opt)
		if err != nil {
			return nil, err
		}
		scheme, addr, port, err := parseHostStringWithScheme(host)
		if err != nil {
			return nil, fmt.Errorf(`unable to parse host "%s": %v`, host, err)
		}
		inst := &discovery.Instance{
			ID:       host,
			Service:  opt.Name,
			Address:  addr,
			Port:     port,
			Meta:     map[string]string{},
			Health:   discovery.HealthPassing,
			RawEntry: host,
		}
		if scheme != "" {
			inst.Meta[kMetaScheme] = scheme
			if scheme == "http" {
				inst.Tags = append(inst.Tags, kTagInsecure+"=true", kTagSecure+"=false")
			} else {
				inst.Tags = append(inst.Tags, kTagInsecure+"=false", kTagSecure+"=true")
			}
		}
		instances = append(instances, inst)
	}
	sort.SliceStable(instances, func(a, b int) bool {
		return instances[a].ID < instances[b].ID
	})
	return instances, nil
}
// splitAddrAndPort splits "host:port" into its host and numeric port.
// A value without any ':' is returned unchanged with port 0.
// Otherwise the value must be a valid net.SplitHostPort input, e.g. "[::1]:53" for IPv6. The port must be a
// decimal number in [0, 65535]; negative or out-of-range ports are rejected instead of being passed through.
func splitAddrAndPort(value string) (string, int, error) {
	if !strings.Contains(value, ":") {
		return value, 0, nil
	}
	addr, portStr, e := net.SplitHostPort(value)
	if e != nil {
		return "", 0, e
	}
	// bitSize 16 bounds the port to the valid TCP/UDP range and rejects signs.
	port, e := strconv.ParseUint(portStr, 10, 16)
	if e != nil {
		return "", 0, fmt.Errorf(`invalid port "%s": %v`, portStr, e)
	}
	return addr, int(port), nil
}
// parseHostStringWithScheme parses "[scheme://]hostname[:port]" into lower-cased scheme and hostname plus port.
// scheme is "" when absent and port is 0 when absent.
// An error is returned when the value does not match hostPatternRegexp (e.g. an empty string). The original
// code panicked on nil submatches in that case. An error is also returned when the hostname is empty or the
// port is invalid.
func parseHostStringWithScheme(value string) (scheme, hostname string, port int, err error) {
	match := hostPatternRegexp.FindStringSubmatch(value)
	if match == nil {
		return "", "", 0, fmt.Errorf(`invalid host "%s"`, value)
	}
	scheme = strings.ToLower(match[hostPatternRegexp.SubexpIndex("scheme")])
	host := strings.ToLower(match[hostPatternRegexp.SubexpIndex("host")])
	if hostname, port, err = splitAddrAndPort(host); err != nil {
		return "", "", 0, err
	}
	if hostname == "" {
		// e.g. "http://:8080" - the hostname part is required
		return "", "", 0, fmt.Errorf(`missing hostname in "%s"`, value)
	}
	return scheme, hostname, port, nil
}
package dnssd
import (
"context"
"embed"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
"io"
)
// logger is the package logger (name "SD.DNS"); NewInstancer uses it as the default instancer logger.
var logger = log.New("SD.DNS")
// defaultConfigFS embeds defaults-discovery.yml, registered as embedded default configuration by Module.
//go:embed defaults-discovery.yml
var defaultConfigFS embed.FS
// Module is the bootstrap module for DNS SRV based service discovery. It:
//   - registers the embedded default configuration (defaults-discovery.yml),
//   - provides DiscoveryProperties and a discovery.Client,
//   - registers a lifecycle stop hook that closes the client.
//
// The name previously read "consul service discovery", a copy-paste leftover that mislabelled this module.
var Module = &bootstrap.Module{
	Name:       "DNS service discovery",
	Precedence: bootstrap.ServiceDiscoveryPrecedence,
	Options: []fx.Option{
		appconfig.FxEmbeddedDefaults(defaultConfigFS),
		fx.Provide(
			BindDiscoveryProperties,
			provideDiscoveryClient),
		fx.Invoke(closeDiscoveryClient),
	},
}
// Use registers Module with the bootstrap framework, enabling DNS SRV based service discovery.
func Use() {
bootstrap.Register(Module)
}
// provideDiscoveryClient builds the DNS discovery.Client from DiscoveryProperties.
// Fallback host mappings are compiled first; an invalid mapping pattern fails the provider.
func provideDiscoveryClient(ctx *bootstrap.ApplicationContext, props DiscoveryProperties) (discovery.Client, error) {
	mappings, err := props.Fallback.CompileMappings()
	if err != nil {
		return nil, err
	}
	configure := func(cfg *ClientConfig) {
		cfg.DNSServerAddr = props.Addr
		cfg.FQDNTemplate = props.FQDNTemplate
		cfg.SRVProto = props.SRVProto
		cfg.SRVService = props.SRVService
		cfg.FallbackHostMappings = mappings
	}
	client := NewDiscoveryClient(ctx, configure)
	return client, nil
}
// closeDiscoveryClient registers a lifecycle stop hook closing the discovery client.
// Clients that do not implement io.Closer are skipped. The previous unchecked type assertion would have
// panicked during application shutdown in that case.
func closeDiscoveryClient(lc fx.Lifecycle, client discovery.Client) {
	closer, ok := client.(io.Closer)
	if !ok {
		return
	}
	lc.Append(fx.StopHook(func(_ context.Context) error {
		return closer.Close()
	}))
}
package dnssd
import (
	"fmt"
	"regexp"
	"sort"

	"github.com/cisco-open/go-lanai/pkg/bootstrap"
	"github.com/pkg/errors"
)
const (
// PropertiesPrefix is the configuration prefix DiscoveryProperties is bound from.
PropertiesPrefix = "cloud.discovery.dns"
)
// DiscoveryProperties defines the static configuration of DNS SRV lookups.
// An SRV query has the form "_<service>._<proto>.<target>", or just "<target>" when service and proto are
// both empty (net.Resolver.LookupSRV then queries the name directly), e.g.
//
//	_http._tcp.my-service.my-namespace.svc.cluster.local
//	my-service.my-namespace.svc.cluster.local
//
// See [RFC2782](https://datatracker.ietf.org/doc/html/rfc2782)
type DiscoveryProperties struct {
// Addr is the address of the DNS server, e.g. "8.8.8.8:53".
// If not set, the default DNS server is used.
// Note: resolving the DNS server address may itself require a DNS lookup, so set this value with caution.
Addr string `json:"addr"`
// FQDNTemplate is a single-line Go template defining how a service name is translated into the
// target domain name (<target>) of the DNS lookup query.
// The template data contains the field ".ServiceName".
// e.g. "{{.ServiceName}}.my-namespace.svc.cluster.local"
FQDNTemplate string `json:"fqdn-template"`
// SRVProto is the "Proto" defined in RFC2782 (the symbolic name of the desired protocol), e.g. "_tcp", "_udp".
// The leading underscore "_" is optional: the DNS Instancer trims leading underscores before the lookup.
// Optional; when specified, SRVService should also be specified.
SRVProto string `json:"srv-proto"`
// SRVService is the "Service" defined in RFC2782 (the symbolic name of the desired service).
// The leading underscore "_" is optional, as for SRVProto. Its value depends on the deployment environment,
// e.g. Kubernetes defines it as the "port name", and Consul doesn't support "Proto" and "Service" in static DNS queries.
// Optional; when specified, SRVProto should also be specified.
SRVService string `json:"srv-service"`
// Fallback defines how service discovery behaves when the DNS lookup cannot resolve any instance.
// See FallbackProperties for more details.
Fallback FallbackProperties `json:"fallback"`
}
// FallbackProperties defines host rewrite as the last resort.
// When the DNS lookup fails (or yields no instance), the discovery client tries the following:
//  1. If the service name matches any entry in Mappings, its HostMappingProperties.Hosts is used as the healthy instances list.
//     This is equivalent to static service discovery.
//  2. If step 1 yields no result, Default is used and the result is a resolved service with a single instance.
//     This is equivalent to using server-side load balancing.
//  3. If none of the above yields a valid result, the original DNS lookup error is recorded and the service is temporarily undiscoverable.
type FallbackProperties struct {
// Mappings defines how to map service names to hosts, based on service name patterns.
// The keys of the map are labels only and do not affect mapping behaviour.
// Do not rely on any particular order in which entries are tried; make sure patterns don't overlap.
Mappings map[string]HostMappingProperties `json:"mappings"`
// Default is a single-line Go template rewriting any service name into a host.
// The template data contains the field ".ServiceName".
// e.g. "{{.ServiceName}}.default.svc.cluster.local:8443"
// This value is used when the service name does not match any entry in Mappings.
// The value can contain scheme, hostname and port, of which "hostname" is required.
// e.g. "http://{{.ServiceName}}.default.svc.cluster.local:8443"
Default string `json:"default"`
}
// CompileMappings compiles Mappings into HostMapping entries. Each entry's HostMappingProperties.Service is
// compiled as a POSIX regular expression.
//
// Entries are emitted in ascending order of their map keys. Go map iteration order is random, so iterating
// the map directly would make both the mapping order and the error reported for an invalid pattern vary
// between runs. When Default is set, a catch-all mapping (pattern `.+`) pointing to Default is appended last.
//
// NOTE(review): patterns are not anchored here; if the consumer uses MatchString, "foo" also matches "foo-bar".
//
// An error is returned for the first (in key order) pattern that fails to compile.
func (p FallbackProperties) CompileMappings() ([]HostMapping, error) {
	keys := make([]string, 0, len(p.Mappings))
	for k := range p.Mappings {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	mappings := make([]HostMapping, 0, len(keys)+1)
	for _, k := range keys {
		m := p.Mappings[k]
		regex, e := regexp.CompilePOSIX(m.Service)
		if e != nil {
			return nil, fmt.Errorf(`invalid service name pattern "%s": %v`, m.Service, e)
		}
		mappings = append(mappings, HostMapping{ServiceRegex: regex, Hosts: m.Hosts})
	}
	if len(p.Default) != 0 {
		mappings = append(mappings, HostMapping{ServiceRegex: regexp.MustCompilePOSIX(`.+`), Hosts: []string{p.Default}})
	}
	return mappings, nil
}
// HostMappingProperties maps service names matching a pattern to a static list of hosts.
type HostMappingProperties struct {
// Service is the service name pattern; CompileMappings compiles it as a POSIX regular expression.
Service string `json:"service"`
// Hosts is a list of known hosts. Each entry should be a single-line Go template.
// The template data contains the field ".ServiceName".
// e.g. "pod-1.{{.ServiceName}}.default.svc.cluster.local:8989"
// The value can contain scheme, hostname and port, of which "hostname" is required.
// e.g. "http://{{.ServiceName}}.default.svc.cluster.local:8443"
Hosts []string `json:"hosts"`
}
// NewDiscoveryProperties returns DiscoveryProperties pre-populated with defaults.
// Only FQDNTemplate has a default: the Kubernetes "default" namespace service domain.
func NewDiscoveryProperties() *DiscoveryProperties {
	props := new(DiscoveryProperties)
	props.FQDNTemplate = "{{.ServiceName}}.default.svc.cluster.local"
	return props
}
// BindDiscoveryProperties binds configuration under PropertiesPrefix onto the defaults from
// NewDiscoveryProperties. It panics when binding fails, aborting application startup.
func BindDiscoveryProperties(ctx *bootstrap.ApplicationContext) DiscoveryProperties {
	props := NewDiscoveryProperties()
	err := ctx.Config().Bind(props, PropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind DiscoveryProperties"))
	}
	return *props
}
package discovery
import (
"context"
)
// ServiceRegistration is the data to be registered with an external service registration system.
// It describes the currently running service instance.
// The implementation depends on which service discovery tech-stack is used,
// e.g. with Consul it would be *consulsd.ServiceRegistration.
type ServiceRegistration interface {
// Read accessors.
ID() string
Name() string
Address() string
Port() int
Tags() []string
Meta() map[string]any
// Mutators, typically used by ServiceRegistrationCustomizer implementations.
SetID(id string)
SetName(name string)
SetAddress(addr string)
SetPort(port int)
AddTags(tags...string)
RemoveTags(tags...string)
SetMeta(key string, value any)
}
// ServiceRegistrar is the interface to interact with an external service registration system.
// Register publishes the given registration and Deregister removes it; both return an error on failure.
type ServiceRegistrar interface {
Register(ctx context.Context, registration ServiceRegistration) error
Deregister(ctx context.Context, registration ServiceRegistration) error
}
// ServiceRegistrationCustomizer customizes the given ServiceRegistration during bootstrap.
// Any ServiceRegistrationCustomizer provided with the fx group defined as FxGroup is applied automatically.
type ServiceRegistrationCustomizer interface {
Customize(ctx context.Context, reg ServiceRegistration)
}
// ServiceRegistrationCustomizerFunc adapts an ordinary function to the ServiceRegistrationCustomizer interface.
type ServiceRegistrationCustomizerFunc func(ctx context.Context, reg ServiceRegistration)

// Compile-time guarantee that the adapter keeps satisfying ServiceRegistrationCustomizer.
var _ ServiceRegistrationCustomizer = ServiceRegistrationCustomizerFunc(nil)

// Customize implements ServiceRegistrationCustomizer by calling fn(ctx, reg).
func (fn ServiceRegistrationCustomizerFunc) Customize(ctx context.Context, reg ServiceRegistration) {
	fn(ctx, reg)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sd
import (
"github.com/cisco-open/go-lanai/pkg/discovery"
"time"
)
// simpleServiceCache implements ServiceCache with map[string]*Service as a back storage.
// simpleServiceCache is not goroutine-safe: even reads may evict expired entries.
type simpleServiceCache struct {
// cache holds the services by key.
cache map[string]*discovery.Service
// exp holds expiration deadlines, only for keys stored with a positive TTL.
exp map[string]time.Time
}
// NewSimpleServiceCache returns a discovery.ServiceCache backed by plain maps, holding entries plus optional
// expiration deadlines.
// The returned cache performs no locking and is therefore not goroutine-safe.
func NewSimpleServiceCache() discovery.ServiceCache {
	c := new(simpleServiceCache)
	c.cache = make(map[string]*discovery.Service)
	c.exp = make(map[string]time.Time)
	return c
}
// Get returns the Service cached under key, or nil when it is absent or expired.
// An expired entry is evicted as a side effect.
func (c *simpleServiceCache) Get(key string) *discovery.Service {
	now := time.Now()
	if expired := c.isExpired(key, now); !expired {
		return c.cache[key]
	}
	return nil
}
// Set stores svc under key without expiration and returns the previously cached, non-expired Service (or nil).
// Any deadline attached to key by an earlier SetWithTTL is cleared. Otherwise the new entry would inherit
// that stale deadline and silently expire even though it was stored without a TTL.
func (c *simpleServiceCache) Set(key string, svc *discovery.Service) *discovery.Service {
	existing := c.Get(key)
	c.cache[key] = svc
	delete(c.exp, key)
	return existing
}
// SetWithTTL stores svc under key, expiring after ttl, and returns the previously cached, non-expired
// Service (or nil). A non-positive ttl behaves like Set.
func (c *simpleServiceCache) SetWithTTL(key string, svc *discovery.Service, ttl time.Duration) *discovery.Service {
	if ttl > 0 {
		previous := c.Get(key)
		c.cache[key] = svc
		c.exp[key] = time.Now().Add(ttl)
		return previous
	}
	return c.Set(key, svc)
}
// Has reports whether a non-nil, non-expired Service is cached under key (evicting it if expired).
func (c *simpleServiceCache) Has(key string) bool {
	return c.Get(key) != nil
}
// Entries returns a new map holding every non-expired entry; expired entries are evicted along the way.
func (c *simpleServiceCache) Entries() map[string]*discovery.Service {
	now := time.Now()
	live := map[string]*discovery.Service{}
	for key, svc := range c.cache {
		if !c.isExpired(key, now) {
			live[key] = svc
		}
	}
	return live
}
// isExpired reports whether key has a deadline strictly before now. Expired keys are removed from both maps.
// Keys without a deadline never expire.
func (c *simpleServiceCache) isExpired(key string, now time.Time) bool {
	deadline, found := c.exp[key]
	if !found || !deadline.Before(now) {
		return false
	}
	delete(c.cache, key)
	delete(c.exp, key)
	return true
}
/***********************
Helpers
***********************/
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
// Package sd provides base implementations of discovery.Client and discovery.Instancer.
package sd
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils/loop"
"reflect"
"sync"
"time"
)
const (
// DefaultRefreshBackoffFactor is the exponential back-off factor of the default refresher installed by MakeCachedInstancer.
DefaultRefreshBackoffFactor float64 = 1.5
)
// InstancerOption holds general options applicable to any cache based instancer implementation.
type InstancerOption struct {
// Name is the service name. CachedInstancer also uses it as the cache key it reads the Service from.
Name string
// Logger is used for service availability logs and, when Verbose is set, internal state change logs.
Logger log.ContextualLogger
// Verbose enables additional logs for internal state changes.
Verbose bool
// RefresherOptions controls how the background refresher repeats refresh attempts, including retries on failure.
// The default set by MakeCachedInstancer is loop.ExponentialRepeatIntervalOnError with initial interval 50ms
// and factor DefaultRefreshBackoffFactor (1.5).
RefresherOptions []loop.TaskOptions
}
// RefreshFunc is the function CachedInstancer uses to refresh its internal cache.
// The refresh result is automatically saved and, if it changed, callbacks are notified.
// The returned *discovery.Service is dereferenced when saved, so it should be non-nil even when an error is
// returned, and its Name should equal InstancerOption.Name.
type RefreshFunc func(ctx context.Context) (*discovery.Service, error)
// CachedInstancerOptions is a functional option used to create a CachedInstancer.
type CachedInstancerOptions func(opt *CachedInstancerOption)
// CachedInstancerOption configures a CachedInstancer.
type CachedInstancerOption struct {
InstancerOption
// BackgroundRefreshFunc is used for background cache refresh.
// This function is repeated in a loop.Loop, so the following behavior is guaranteed:
//   - Executions are sequential; no concurrency between executions needs to be considered.
//   - The next execution is not scheduled until the current one finishes, so this function may block the current goroutine.
// The repeat interval is configurable via InstancerOption.RefresherOptions.
BackgroundRefreshFunc RefreshFunc
// ForegroundRefreshFunc is used for foreground cache refresh, invoked directly via CachedInstancer.RefreshNow.
// This function is always invoked in the caller's goroutine, so concurrency needs to be considered,
// and cancellation/timeout need to be taken care of.
ForegroundRefreshFunc RefreshFunc
}
// CachedInstancer implements discovery.Instancer.
// It provides a common implementation of discovery.Instancer with an internal cache and a background loop.Loop
// that periodically refreshes the service cache using the provided RefreshFunc.
// Start must be called before Service/Instances/RefreshNow are used (it creates readyCond and the loop context).
// See discovery.Instancer
type CachedInstancer struct {
CachedInstancerOption
// NOTE(review): ctx is not referenced by any code visible here.
ctx context.Context
// readyCond is broadcast after every refresh; it is bound to cacheMtx's read locker (see Start).
readyCond *sync.Cond
cacheMtx sync.RWMutex // RW Lock for cache
stateMtx sync.RWMutex // RW Mutex for state, such as start/stop, callback/subscription update
looper *loop.Loop
// loopCtx and cancelFunc are set by Start; loopCtx being non-nil means the loop was started.
loopCtx context.Context
cancelFunc context.CancelFunc
cache discovery.ServiceCache
// callbacks are keyed by the caller-supplied registration id.
callbacks map[interface{}]discovery.Callback
}
// MakeCachedInstancer returns a CachedInstancer that provides a basic implementation of discovery.Instancer.
// Unless opts override it, the refresher repeats on error with exponential back-off starting at 50ms
// (factor DefaultRefreshBackoffFactor). The returned value is not started; call Start.
// See discovery.Instancer
func MakeCachedInstancer(opts ...CachedInstancerOptions) CachedInstancer {
	backoff := loop.ExponentialRepeatIntervalOnError(50*time.Millisecond, DefaultRefreshBackoffFactor)
	var opt CachedInstancerOption
	opt.RefresherOptions = []loop.TaskOptions{backoff}
	for _, configure := range opts {
		configure(&opt)
	}
	return CachedInstancer{
		CachedInstancerOption: opt,
		looper:                loop.NewLoop(),
		cache:                 NewSimpleServiceCache(),
		callbacks:             make(map[interface{}]discovery.Callback),
	}
}
// ServiceName implements discovery.Instancer. It returns the configured InstancerOption.Name.
func (i *CachedInstancer) ServiceName() string {
return i.Name
}
// Service implements discovery.Instancer.
// It takes only the cache read lock and blocks until the first refresh result is available.
func (i *CachedInstancer) Service() (svc *discovery.Service) {
	i.cacheMtx.RLock()
	defer i.cacheMtx.RUnlock()
	svc = i.service()
	return
}
// Instances implements discovery.Instancer.
// It waits for the first refresh result under the cache read lock. If the loop context was cancelled
// (e.g. by Stop), the cached data is no longer trusted and an empty list with discovery.ErrInstancerStopped is
// returned. Otherwise it returns the instances accepted by matcher together with the last refresh error (if any).
func (i *CachedInstancer) Instances(matcher discovery.InstanceMatcher) (ret []*discovery.Instance, err error) {
	i.cacheMtx.RLock()
	defer i.cacheMtx.RUnlock()
	svc := i.service()
	if stopped := errors.Is(i.loopCtx.Err(), context.Canceled); stopped {
		return []*discovery.Instance{}, discovery.ErrInstancerStopped
	}
	err = svc.Err
	ret = svc.Instances(matcher)
	return ret, err
}
// Start launches the background refresh loop on ctx; calls after the first one are no-ops.
// It also creates readyCond, which Service/Instances/onRefresh rely on.
func (i *CachedInstancer) Start(ctx context.Context) {
	i.stateMtx.Lock()
	defer i.stateMtx.Unlock()
	alreadyStarted := i.loopCtx != nil
	if alreadyStarted {
		return
	}
	// Readers wait while holding the cache read lock (see service), so the condition uses the read locker.
	i.readyCond = sync.NewCond(i.cacheMtx.RLocker())
	i.loopCtx, i.cancelFunc = i.looper.Run(ctx)
	i.looper.Repeat(i.refreshTask(), i.RefresherOptions...)
}
// RegisterCallback registers (or replaces) cb under id; nil id or nil cb are ignored.
// The callback is not invoked immediately, only on subsequent notable refreshes.
func (i *CachedInstancer) RegisterCallback(id interface{}, cb discovery.Callback) {
	if id == nil || cb == nil {
		return
	}
	i.stateMtx.Lock()
	defer i.stateMtx.Unlock()
	i.callbacks[id] = cb
}
// DeregisterCallback removes the callback registered under id; a nil id is ignored.
func (i *CachedInstancer) DeregisterCallback(id interface{}) {
	if id != nil {
		i.stateMtx.Lock()
		defer i.stateMtx.Unlock()
		delete(i.callbacks, id)
	}
}
// Stop implements discovery.Instancer.
// It cancels the refresh loop started by Start; it does nothing if Start was never called.
func (i *CachedInstancer) Stop() {
	i.stateMtx.Lock()
	defer i.stateMtx.Unlock()
	if cancel := i.cancelFunc; cancel != nil {
		cancel()
	}
}
// RefreshNow invokes ForegroundRefreshFunc immediately and records the result like a background refresh.
// Note: the refresh function runs in the caller's goroutine.
func (i *CachedInstancer) RefreshNow(ctx context.Context) (*discovery.Service, error) {
return i.refresh(ctx, i.ForegroundRefreshFunc)
}
// service returns the cached Service for i.Name, waiting on readyCond until a refresh has stored it.
// It is not goroutine-safe by itself: callers must hold the cache read lock. readyCond.Wait releases that
// lock while waiting and re-acquires it before returning.
func (i *CachedInstancer) service() (svc *discovery.Service) {
	for {
		if i.cache.Has(i.Name) {
			return i.cache.Get(i.Name)
		}
		i.readyCond.Wait()
	}
}
// refreshTask adapts BackgroundRefreshFunc to a loop.TaskFunc repeated by the background loop (see Start).
func (i *CachedInstancer) refreshTask() loop.TaskFunc {
return func(ctx context.Context, _ *loop.Loop) (ret interface{}, err error) {
return i.refresh(ctx, i.BackgroundRefreshFunc)
}
}
// refresh runs fn, records its outcome via onRefresh (cache update, ready broadcast, callbacks) and
// returns fn's results unchanged.
func (i *CachedInstancer) refresh(ctx context.Context, fn RefreshFunc) (*discovery.Service, error) {
	svc, err := fn(ctx)
	i.onRefresh(ctx, svc, err)
	return svc, err
}
// onRefresh stores a refresh result in the cache under service.Name and derives FirstErrAt from the
// previously cached entry. It then wakes goroutines waiting in service() and, when shouldNotify says the
// change is notable, logs it and invokes the registered callbacks.
// NOTE(review): service is dereferenced unconditionally, so a RefreshFunc returning nil would panic here.
// NOTE(review): readyCond is created in Start; calling this before Start would panic on Broadcast.
func (i *CachedInstancer) onRefresh(ctx context.Context, service *discovery.Service, err error) {
i.cacheMtx.Lock()
var notify bool
defer func() {
// we need to release the write lock before invoking callbacks
i.cacheMtx.Unlock()
i.readyCond.Broadcast()
if notify {
i.invokeCallbacks()
}
}()
// record result
existing := i.cache.Set(service.Name, service)
// FirstErrAt is set after storing; service is a pointer, so the cached entry sees the update
service.FirstErrAt = i.determineFirstErrTime(err, existing)
notify = i.shouldNotify(service, existing)
if notify {
i.logUpdate(ctx, service, existing)
}
}
// invokeCallbacks calls every registered callback with this instancer while holding the state read lock.
// Callbacks must therefore not register or deregister callbacks themselves (that needs the write lock).
func (i *CachedInstancer) invokeCallbacks() {
	i.stateMtx.RLock()
	defer i.stateMtx.RUnlock()
	for id := range i.callbacks {
		i.callbacks[id](i)
	}
}
// determineFirstErrTime computes FirstErrAt for a new refresh result:
//   - zero time when the refresh succeeded,
//   - the previous FirstErrAt when the previous result already carried an error (consecutive failures),
//   - now when this is the first failure.
func (i *CachedInstancer) determineFirstErrTime(err error, old *discovery.Service) time.Time {
	if err == nil {
		return time.Time{}
	}
	if old != nil && old.Err != nil {
		return old.FirstErrAt
	}
	return time.Now()
}
// shouldNotify reports whether a refresh is notable. That is true when exactly one of the two snapshots is
// nil, when the instance lists differ (deep comparison), or when the result flipped between error and success.
func (i *CachedInstancer) shouldNotify(new, old *discovery.Service) bool {
	if new == nil || old == nil {
		// both nil -> false; exactly one nil -> true
		return new != old
	}
	instancesChanged := !reflect.DeepEqual(new.Insts, old.Insts)
	errorStateChanged := (new.Err == nil) != (old.Err == nil)
	return instancesChanged || errorStateChanged
}
// logUpdate logs a notable refresh. Verbose mode adds detailed logs. Regular logs are emitted only when the
// number of healthy instances changes between zero and non-zero.
func (i *CachedInstancer) logUpdate(ctx context.Context, new, old *discovery.Service) {
	if i.Verbose {
		i.verboseLog(ctx, new, old)
	}
	healthyCount := func(svc *discovery.Service) int {
		if svc == nil {
			return 0
		}
		return svc.InstanceCount(discovery.InstanceIsHealthy())
	}
	before, now := healthyCount(old), healthyCount(new)
	switch {
	case before == 0 && now > 0:
		i.Logger.WithContext(ctx).Infof("service [%s] became available", i.Name)
	case before > 0 && now == 0:
		i.Logger.WithContext(ctx).Warnf("service [%s] healthy instances dropped to 0", i.Name)
	}
}
// verboseLog logs the first error of a failing streak at info level. Any other update is logged at debug
// level with instance diff statistics.
func (i *CachedInstancer) verboseLog(ctx context.Context, new, old *discovery.Service) {
	firstFailure := new != nil && new.Err != nil && (old == nil || old.Err == nil)
	if firstFailure {
		i.Logger.WithContext(ctx).Infof("error when finding instances for service %s: %v", i.Name, new.Err)
		return
	}
	changes := diff(new, old)
	i.Logger.WithContext(ctx).Debugf(`refreshed instances %s: [healthy=%d] [unchanged=%d] [updated=%d] [new=%d] [removed=%d]`, i.Name,
		len(changes.healthy), len(changes.unchanged), len(changes.updated), len(changes.added), len(changes.deleted))
}
/***********************
Helpers
***********************/
// svcDiff classifies instances between two service snapshots (see diff).
// healthy lists the new snapshot's instances with discovery.HealthPassing; unchanged/updated are new instances
// whose ID exists in both snapshots; added and deleted are IDs present only in the new or old snapshot.
type svcDiff struct {
healthy,
unchanged,
updated,
added,
deleted []*discovery.Instance
}
// diff compares two snapshots of a service and classifies their instances into a svcDiff.
// Both instance lists are assumed to be sorted by ID, because the comparison is a sorted merge.
//
// After the merge, whatever remains on the longer side has no counterpart on the other side. Leftover new
// instances are "added" and leftover old instances are "deleted". Without this tail handling, e.g. growing
// from 0 to 2 instances reported "new=0".
func diff(new, old *discovery.Service) (ret *svcDiff) {
	ret = &svcDiff{}
	switch {
	case new == nil && old != nil:
		ret.deleted = old.Insts
		return
	case new != nil && old == nil:
		ret.added = new.Insts
		for _, inst := range ret.added {
			if inst.Health == discovery.HealthPassing {
				ret.healthy = append(ret.healthy, inst)
			}
		}
		return
	case new == nil || old == nil:
		return
	}
	// find differences, Note that we know instances are sorted by ID
	newN, oldN := len(new.Insts), len(old.Insts)
	newI, oldI := 0, 0
	for newI < newN && oldI < oldN {
		newInst, oldInst := new.Insts[newI], old.Insts[oldI]
		switch {
		case newInst.ID > oldInst.ID:
			oldI++
			ret.deleted = append(ret.deleted, oldInst)
		case newInst.ID < oldInst.ID:
			newI++
			ret.added = append(ret.added, newInst)
		default:
			newI++
			oldI++
			if !reflect.DeepEqual(newInst, oldInst) {
				ret.updated = append(ret.updated, newInst)
			} else {
				ret.unchanged = append(ret.unchanged, newInst)
			}
		}
	}
	// at most one of these tails is non-empty
	ret.added = append(ret.added, new.Insts[newI:]...)
	ret.deleted = append(ret.deleted, old.Insts[oldI:]...)
	for _, inst := range new.Insts {
		if inst.Health == discovery.HealthPassing {
			ret.healthy = append(ret.healthy, inst)
		}
	}
	return
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consuldsync
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/dsync"
"github.com/cisco-open/go-lanai/pkg/utils/xsync"
"github.com/hashicorp/consul/api"
"sync"
"time"
)
const (
// lockFlagValue is a magic flag we set to indicate a key is being used for a lock.
// It is used to detect a potential conflict with a semaphore; acquireLock stores it in KVPair.Flags.
lockFlagValue = 0x275f2b610e0c3019
)
// ConsulLockOptions is a functional option applied by newConsulLock after defaults are set.
type ConsulLockOptions func(opt *ConsulLockOption)
// ConsulLockOption configures a ConsulLock.
type ConsulLockOption struct {
// Context is the parent of the lock loop's context (see startLoop). Defaults to context.Background().
Context context.Context
// SessionFunc returns the Consul session ID to acquire the lock with; called at the start of every lock-loop iteration.
SessionFunc func(context.Context) (string, error)
Key string // Must be set and have write permissions
Valuer dsync.LockValuer // cannot be nil, valuer to associate with the lock. Default to static json marshaller
QueryWaitTime time.Duration // how long we block per GET to check if lock acquisition is possible
RetryDelay time.Duration // how long we wait after a retryable error (usually network error)
}
// consulLockState is the lock-loop state observed by Lock/TryLock (see updateState).
type consulLockState int
const (
// stateUnknown: initial state, and the state after the lock loop exits.
stateUnknown consulLockState = iota
// stateAcquired: the lock is currently held.
stateAcquired
// stateError: the latest attempt failed or the lock was lost; lastErr holds the reason.
stateError
)
// ConsulLock implements the Lock interface using the consul lock described at https://www.consul.io/docs/guides/leader-election.html
// The implementation is a modified api.Lock. The major differences are:
//   - Session is created/maintained outside. There is no session creation when attempting to lock.
//   - "lock or wait" vs "try lock and return" is not pre-determined via options.
type ConsulLock struct {
// mtx guards the state variables below; stateCond is bound to it.
mtx sync.Mutex
client *api.Client
option ConsulLockOption
// State Variables, requires mutex lock to read and write
loopContext context.Context
loopCancelFunc context.CancelFunc
// lockLostCh is closed whenever the lock is not held (it starts out closed); a fresh one is created on acquisition.
lockLostCh chan struct{}
state consulLockState
stateCond *xsync.Cond
// session is the last session ID obtained from SessionFunc.
session string
refreshFunc context.CancelFunc // used when current acquisition should be stopped and restarted
lastErr error
}
// newConsulLock creates an un-started ConsulLock. The lock loop is started lazily by Lock/TryLock.
// Defaults: background context, 10m query wait, 2s retry delay, and a static JSON valuer.
// Options are applied afterwards.
func newConsulLock(client *api.Client, opts ...ConsulLockOptions) *ConsulLock {
	// Lost() must report "not held" before the first acquisition, hence an already-closed channel.
	lostCh := make(chan struct{}, 1)
	close(lostCh)
	lock := &ConsulLock{
		client: client,
		option: ConsulLockOption{
			Context:       context.Background(),
			QueryWaitTime: 10 * time.Minute,
			RetryDelay:    2 * time.Second,
			Valuer: dsync.NewJsonLockValuer(map[string]string{
				"name": "consul distributed lock",
			}),
		},
		lockLostCh: lostCh,
	}
	lock.stateCond = xsync.NewCond(&lock.mtx)
	for _, apply := range opts {
		apply(&lock.option)
	}
	return lock
}
// Key returns the Consul KV key used for this lock.
func (l *ConsulLock) Key() string {
return l.option.Key
}
// Lock implements dsync.Lock.
// It starts the lock loop if needed and blocks until the lock is acquired, ctx ends, or the loop is stopped.
// A stopped loop yields context.Canceled. Error states do not end the wait.
// The acquired lock may get revoked from server-side, unless the session is specifically created without any
// associated health checks.
func (l *ConsulLock) Lock(ctx context.Context) error {
	l.lazyStart()
	acquiredOrStopped := func(state consulLockState) (bool, error) {
		if state == stateAcquired {
			return true, nil
		}
		if l.loopContext == nil {
			return true, context.Canceled
		}
		return false, nil
	}
	return l.waitForState(ctx, acquiredOrStopped)
}
// TryLock starts the lock loop if needed and waits for the first decisive state. It returns nil once the lock is
// acquired, lastErr as soon as the loop reports an error, or context.Canceled when the loop is stopped.
// Unlike Lock, it does not keep waiting through error states.
func (l *ConsulLock) TryLock(ctx context.Context) error {
	l.lazyStart()
	return l.waitForState(ctx, func(state consulLockState) (bool, error) {
		switch state {
		case stateAcquired:
			return true, nil
		case stateError:
			return true, l.lastErr
		}
		if l.loopContext == nil {
			return true, context.Canceled
		}
		return false, nil
	})
}
// Release releases the lock; it delegates to the internal release implementation.
func (l *ConsulLock) Release() error {
return l.release()
}
// Lost returns a channel that is closed when the lock is not (or no longer) held.
// The channel is replaced on every acquisition, so callers should re-fetch it after acquiring the lock.
func (l *ConsulLock) Lost() <-chan struct{} {
l.mtx.Lock()
defer l.mtx.Unlock()
return l.lockLostCh
}
// lazyStart starts the lock loop unless it is already running.
func (l *ConsulLock) lazyStart() {
	l.mtx.Lock()
	defer l.mtx.Unlock()
	if running := l.loopContext != nil; !running {
		l.startLoop()
	}
}
// waitForState blocks, holding mtx except while waiting on stateCond, until stateMatcher reports done for the
// current state. It returns the matcher's error, or ctx's cancellation/deadline error if ctx ends first.
// Other errors from stateCond.Wait are ignored and the wait continues.
func (l *ConsulLock) waitForState(ctx context.Context, stateMatcher func(consulLockState) (bool, error)) error {
	l.mtx.Lock()
	defer l.mtx.Unlock()
	for {
		if done, err := stateMatcher(l.state); done {
			return err
		}
		if err := l.stateCond.Wait(ctx); errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			return err
		}
	}
}
// updateState atomically updates the state, executes additional setters and broadcasts the change.
// If the given state is negative, only the setters are executed (no state change, no broadcast).
// Transitions:
//   - entering stateAcquired from another state creates a fresh lockLostCh;
//   - leaving stateAcquired closes the current lockLostCh, notifying Lost() listeners;
//   - waiters are woken on every state change, and on every stateError update even when already in error
//     (presumably so TryLock observes the refreshed lastErr).
func (l *ConsulLock) updateState(s consulLockState, setters ...func()) {
l.mtx.Lock()
defer l.mtx.Unlock()
for _, fn := range setters {
fn()
}
if s < 0 {
return
}
if s == stateAcquired && l.state != s {
l.lockLostCh = make(chan struct{}, 1)
} else if l.state == stateAcquired && l.state != s {
close(l.lockLostCh)
}
// deferred so the broadcast happens after l.state is updated (still under mtx)
if s == stateError || l.state != s {
defer l.stateCond.Broadcast()
}
l.state = s
}
// startLoop kicks off the lock loop in a new goroutine, on a cancellable child of option.Context.
// The mutex lock is required when calling this function.
func (l *ConsulLock) startLoop() {
l.loopContext, l.loopCancelFunc = context.WithCancel(l.option.Context)
go l.lockLoop(l.loopContext, l.loopCancelFunc)
}
// stopLoop cancels the lock loop and clears the loop fields, so Lock/TryLock observe a stopped loop.
// The mutex lock is required when calling this function, and the loop must be running (loopCancelFunc non-nil).
func (l *ConsulLock) stopLoop() {
l.loopCancelFunc()
l.loopContext = nil
l.loopCancelFunc = nil
}
// refresh is called by the session manager to notify a potential change of session ID.
// It cancels the current acquisition attempt (if any) so the lock loop restarts with a fresh session.
func (l *ConsulLock) refresh() {
	l.mtx.Lock()
	defer l.mtx.Unlock()
	if cancelAttempt := l.refreshFunc; cancelAttempt != nil {
		cancelAttempt()
	}
}
// lockLoop is the main loop attempting to maintain the lock; the state alternates between Acquired and Error.
// Each iteration:
//  1. derives a cancellable refreshCtx from ctx and publishes its cancel func as l.refreshFunc, so refresh()
//     can abort the current attempt and restart it;
//  2. obtains the current session via SessionFunc;
//  3. acquires the lock, then monitors it until it is lost or the attempt is cancelled.
//
// The loop exits when ctx is cancelled (it may also be cancelled outside, e.g. when the lock is released).
// The state is then reset to stateUnknown and cancelFunc is invoked.
//
// Each iteration's child context is cancelled before the next one is created, and on exit. Without this,
// an iteration ending on a session error, an acquisition error or a lost lock would leave its child context
// registered on ctx until ctx itself is cancelled, growing without bound on a long-lived lock.
func (l *ConsulLock) lockLoop(ctx context.Context, cancelFunc context.CancelFunc) {
	defer cancelFunc()
	var cancelRefresh context.CancelFunc
	defer func() {
		if cancelRefresh != nil {
			cancelRefresh()
		}
	}()
LOOP:
	for {
		select {
		case <-ctx.Done():
			break LOOP
		default:
		}
		// release the previous iteration's child context; calling a CancelFunc again is a no-op
		if cancelRefresh != nil {
			cancelRefresh()
		}
		// update refresh func, but keep the current state
		refreshCtx, fn := context.WithCancel(ctx)
		cancelRefresh = fn
		l.updateState(-1, func() { l.refreshFunc = fn })
		// grab current session.
		// Note: in case of error, we don't reset previously used session,
		// the release function will try to release lock using previously used session
		session, e := l.option.SessionFunc(refreshCtx)
		switch {
		case errors.Is(e, context.Canceled) || errors.Is(e, context.DeadlineExceeded):
			// current acquisition is cancelled
			continue
		case e != nil:
			// NOTE(review): there is no delay on this path; if SessionFunc fails fast, this loop spins.
			l.updateState(stateError, func() { l.lastErr = dsync.ErrSessionUnavailable })
			continue
		default:
			l.updateState(-1, func() { l.session = session })
		}
		// try to acquire lock
		switch e := l.acquireLock(refreshCtx, session, 0); {
		case errors.Is(e, context.Canceled) || errors.Is(e, context.DeadlineExceeded):
			// current acquisition is cancelled
			continue
		case e == nil:
			// lock acquired, continue
			logger.WithContext(refreshCtx).Debugf("acquired lock [%s]", l.option.Key)
			l.updateState(stateAcquired, func() { l.lastErr = nil })
		default:
			l.updateState(stateError, func() { l.lastErr = e })
			continue
		}
		// up to this point, we have acquired the lock. enter monitor state
		switch e := l.monitorLock(refreshCtx, session); {
		case errors.Is(e, context.Canceled) || errors.Is(e, context.DeadlineExceeded):
			// current acquisition is cancelled
			continue
		default:
			// we lost the lock
			logger.WithContext(refreshCtx).Debugf("lost lock [%s] - %v", l.option.Key, e)
			l.updateState(stateError, func() { l.lastErr = e })
		}
	}
	// we lost lock
	l.updateState(stateUnknown)
}
// acquireLock tries to acquire the consul KV lock (option.Key, lockFlagValue) for the given session.
//
// When maxWait <= 0, it waits for the lock to become available (blocking queries on ctx). It keeps
// retrying until the lock is acquired or ctx is cancelled.
// When maxWait > 0, it returns dsync.ErrLockUnavailable as soon as the lock is observed to be held by
// another session. Consul reads made while checking are bounded by maxWait.
//
// Errors coming from consul are wrapped with %w, so callers can detect context cancellation with errors.Is.
func (l *ConsulLock) acquireLock(ctx context.Context, session string, maxWait time.Duration) error {
	kv := l.client.KV()
	pair := &api.KVPair{
		Key:     l.option.Key,
		Value:   l.option.Valuer(),
		Session: session,
		Flags:   lockFlagValue,
	}
	waitUntilAvailable := maxWait <= 0
	waitCtx := ctx
	if !waitUntilAvailable {
		var cancelFunc context.CancelFunc
		waitCtx, cancelFunc = context.WithTimeout(ctx, maxWait)
		defer cancelFunc()
	}
	for {
		// try to acquire lock
		acquired, _, e := kv.Acquire(pair, nil)
		switch {
		case e != nil:
			// We cannot acquire the lock at the moment, possibly due to:
			// - a network error
			// - any 500 (e.g. session id is not valid)
			// Pause before returning so that an immediate retry by the caller doesn't hammer consul.
			l.delay(ctx, l.option.RetryDelay)
			return fmt.Errorf("failed to acquire lock: %w", e)
		case acquired:
			return nil
		}
		// handle failure, might wait until lock become available and try again
		switch current, e := l.handleAcquisitionFailure(waitCtx, session, waitUntilAvailable); {
		case e != nil:
			return e
		case current == session:
			// the lock is already held by our own session
			return nil
		case current != "" && !waitUntilAvailable:
			return dsync.ErrLockUnavailable.WithMessage(`lock [%s] is held by another session`, l.option.Key)
		}
		// at this point, lock is not held by any session, but it may be in LockDelay period. pause and retry
		if !l.delay(ctx, l.option.RetryDelay) {
			return context.Canceled
		}
	}
}
// handleAcquisitionFailure handles a lock acquisition failure. The provided ctx must be a cancellable context.
// The function blocks until one of the following conditions is met:
//
//  1. the provided context is cancelled or timed out (returns context.Canceled)
//  2. When waitUntilAvailable == true:
//     2.1 the lock becomes available (the lock is not held by any session)
//     2.2 the lock is held by its own session
//     (this normally shouldn't happen, unless we attempt to recover a previously held lock from a network error)
//  3. When waitUntilAvailable == false:
//     3.1 the current owner of the lock has been read (regardless of whether the lock is available)
//  4. consul becomes unavailable or the key holds a non-lock value (returns an error)
//
// Returns the current owner session ID ("" when nobody holds the lock).
// Consul read errors are wrapped with %w, so a context cancellation surfacing from the blocking
// query remains detectable with errors.Is.
//
// Note: when this function returns, the lock might be in the lock-delay period, meaning no session can acquire it.
func (l *ConsulLock) handleAcquisitionFailure(ctx context.Context, session string, waitUntilAvailable bool) (currentOwner string, err error) {
	kv := l.client.KV()
	qOpts := (&api.QueryOptions{
		WaitTime: l.option.QueryWaitTime,
	}).WithContext(ctx)
	for i := 0; ; i++ {
		logger.WithContext(ctx).Debugf("wait attempt %d, WaitIndex=%d, WaitTime=%v", i, qOpts.WaitIndex, qOpts.WaitTime)
		// Look for an existing lock and handle error. potentially blocking operation
		pair, meta, e := kv.Get(l.option.Key, qOpts)
		var owner string
		switch {
		case e != nil:
			return "", fmt.Errorf("failed to read lock: %w", e)
		case pair != nil && pair.Flags != lockFlagValue:
			// the key exists but was not written as a lock by us
			return "", api.ErrLockConflict
		case pair != nil:
			owner = pair.Session
		}
		// potentially retryable situations
		switch {
		case owner == "" || owner == session:
			// the lock is held by current session OR the lock is not held by any session
			return owner, nil
		case !waitUntilAvailable:
			return owner, nil
		default:
			// update error state and retry
			l.updateState(stateError, func() { l.lastErr = dsync.ErrLockUnavailable.WithMessage(`lock [%s] is held by another session`, l.option.Key) })
		}
		// see if cancelled
		select {
		case <-ctx.Done():
			return owner, context.Canceled
		default:
		}
		// up to this point, we know the lock is held by other session, and context is not cancelled or timed out.
		// block on the next query until the key changes
		qOpts.WaitIndex = meta.LastIndex
	}
}
// monitorLock is a long-running routine that watches the lock's KV entry (consistent, blocking
// queries) to confirm that the given session still owns it.
// It returns:
//   - context.Canceled when ctx is cancelled (e.g. by refreshFunc or lock release). Transient errors
//     observed before the cancellation are ignored, because they don't mean the lock was lost.
//   - an error when the lock is no longer held by the given session ("lock revoked by server").
//   - the consul error when a non-retryable error occurs.
func (l *ConsulLock) monitorLock(ctx context.Context, session string) error {
	kv := l.client.KV()
	opts := (&api.QueryOptions{
		RequireConsistent: true,
	}).WithContext(ctx)
	for {
		select {
		case <-ctx.Done():
			return context.Canceled
		default:
		}
		pair, meta, e := kv.Get(l.option.Key, opts)
		switch {
		case e != nil && ctx.Err() != nil:
			// the query was aborted by our own cancellation, not by a consul failure
			return context.Canceled
		case e != nil && api.IsRetryableError(e):
			// network error or something we can retry later
			if l.delay(ctx, l.option.RetryDelay) {
				opts.WaitIndex = 0
			}
		case e != nil:
			// other non-recoverable error, quit
			return e
		case pair != nil && pair.Session == session:
			// everything is fine, we enter long wait monitoring
			opts.WaitIndex = meta.LastIndex
		default:
			// lock is lost, quit
			return fmt.Errorf("lock revoked by server")
		}
	}
}
// delay blocks for the given duration or until ctx is done, whichever happens first.
// It reports true when the full delay elapsed and false when ctx ended the wait early.
func (l *ConsulLock) delay(ctx context.Context, delay time.Duration) (success bool) {
	timer := time.NewTimer(delay)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}
// release stops the lock loop and, when a session was previously used, explicitly releases the consul
// KV lock held by that session. It is a no-op when the lock loop is not running.
// Returns dsync.ErrUnlockFailed (with the consul error as cause) if the KV release call fails.
func (l *ConsulLock) release() error {
	// hold the lock mutex for the whole release
	l.mtx.Lock()
	defer l.mtx.Unlock()

	if l.loopContext == nil {
		// lock loop is not active, nothing to release
		return nil
	}
	l.stopLoop()

	if l.session == "" {
		// no known session to release the KV lock with
		return nil
	}
	kvPair := &api.KVPair{Key: l.option.Key, Session: l.session, Flags: lockFlagValue}
	if _, _, releaseErr := l.client.KV().Release(kvPair, nil); releaseErr != nil {
		return dsync.ErrUnlockFailed.WithCause(releaseErr)
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package consuldsync
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/consul"
"github.com/cisco-open/go-lanai/pkg/dsync"
"github.com/cisco-open/go-lanai/pkg/utils/xsync"
"github.com/hashicorp/consul/api"
"sync"
"time"
)
// ConsulSyncManager implements SyncManager leveraging consul's session feature
// See https://learn.hashicorp.com/tutorials/consul/application-leader-elections?in=consul/developer-configuration
// https://learn.hashicorp.com/tutorials/consul/distributed-semaphore
// https://www.consul.io/docs/dynamic-app-config/sessions
//
// A single consul session is created lazily (when a lock first asks for one via waitForSession).
// It is shared by all locks created by this manager.
type ConsulSyncManager struct {
	// mtx guards shutdown, session, cancelFunc and locks; sessionCond is bound to it
	mtx    sync.Mutex
	appCtx *bootstrap.ApplicationContext
	client *api.Client
	option ConsulSessionOption
	// shutdown is set by Stop; afterwards no new locks or session loops are created
	shutdown bool
	// session is the current consul session ID, "" while no valid session is available
	session string
	// sessionCond is broadcast whenever session changes (see updateSession)
	sessionCond *xsync.Cond
	// cancelFunc stops the running session loop; nil when the loop is not running
	cancelFunc context.CancelFunc
	// locks holds every lock created by this manager, keyed by lock key
	locks map[string]*ConsulLock
}
// ConsulSessionOptions customizes the ConsulSessionOption used by NewConsulLockManager.
type ConsulSessionOptions func(opt *ConsulSessionOption)

// ConsulSessionOption configures the consul session created and maintained by ConsulSyncManager.
type ConsulSessionOption struct {
	// Name of the consul session (defaults to the application name)
	Name string
	// TTL of the session; also the TTL used when renewing it (default 10s)
	TTL time.Duration
	// LockDelay passed to consul when creating the session (default 2s)
	LockDelay time.Duration
	// RetryDelay is how long to wait before retrying a failed session creation (default 2s)
	RetryDelay time.Duration
}
// NewConsulLockManager creates a ConsulSyncManager backed by the given consul connection.
// Session defaults are:
//   - Name = application name
//   - TTL = 10s
//   - LockDelay = 2s
//   - RetryDelay = 2s
// They can be overridden with opts.
// No consul session is created here; the session loop is started lazily when a lock needs a session.
func NewConsulLockManager(ctx *bootstrap.ApplicationContext, conn *consul.Connection, opts ...ConsulSessionOptions) (ret *ConsulSyncManager) {
	ret = &ConsulSyncManager{
		appCtx: ctx,
		client: conn.Client(),
		option: ConsulSessionOption{
			Name:       ctx.Name(),
			TTL:        10 * time.Second,
			LockDelay:  2 * time.Second,
			RetryDelay: 2 * time.Second,
		},
		locks: make(map[string]*ConsulLock),
	}
	ret.sessionCond = xsync.NewCond(&ret.mtx)
	for _, fn := range opts {
		fn(&ret.option)
	}
	return
}
// Start implements dsync.SyncManagerLifecycle. It intentionally does nothing:
// the session loop is started lazily on the first session request (see waitForSession).
func (m *ConsulSyncManager) Start(context.Context) error {
	return nil
}
// Stop implements dsync.SyncManagerLifecycle. It:
//   - switches the manager into shutdown mode,
//   - stops the session loop, and
//   - releases every lock created by this manager.
// Release failures are logged as warnings and never fail Stop.
func (m *ConsulSyncManager) Stop(ctx context.Context) error {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	m.shutdown = true
	m.stopLoop()
	for key, lock := range m.locks {
		releaseErr := lock.Release()
		if releaseErr == nil {
			continue
		}
		logger.WithContext(ctx).Warnf("Failed to release lock [%s]: %v", key, releaseErr)
	}
	return nil
}
// Lock implements dsync.SyncManager. It returns the existing lock for key if one was created before
// (opts are then ignored); otherwise it creates, registers and returns a new ConsulLock.
// It fails when key is empty, or with dsync.ErrSyncManagerStopped after Stop.
func (m *ConsulSyncManager) Lock(key string, opts ...dsync.LockOptions) (dsync.Lock, error) {
	if len(key) == 0 {
		return nil, fmt.Errorf(`cannot create distributed lock: key is required but missing`)
	}
	lockOpt := dsync.LockOption{
		Valuer: dsync.NewJsonLockValuer(map[string]string{
			"name": fmt.Sprintf("distributed lock - %s", m.appCtx.Name()),
		}),
	}
	for _, apply := range opts {
		apply(&lockOpt)
	}

	m.mtx.Lock()
	defer m.mtx.Unlock()
	if m.shutdown {
		return nil, dsync.ErrSyncManagerStopped
	}
	if existing, found := m.locks[key]; found {
		return existing, nil
	}
	m.locks[key] = newConsulLock(m.client, func(opt *ConsulLockOption) {
		opt.Context = m.appCtx
		opt.SessionFunc = m.waitForSession
		opt.Key = key
		opt.Valuer = lockOpt.Valuer
	})
	return m.locks[key], nil
}
// startLoop starts the session loop unless it is already running.
// It returns dsync.ErrSyncManagerStopped after shutdown. The caller must hold m.mtx.
func (m *ConsulSyncManager) startLoop() error {
	switch {
	case m.shutdown:
		return dsync.ErrSyncManagerStopped
	case m.cancelFunc != nil:
		// session loop already running
		return nil
	}
	loopCtx, cancel := context.WithCancel(m.appCtx)
	m.cancelFunc = cancel
	go m.sessionLoop(loopCtx)
	return nil
}
// stopLoop cancels the session loop if it is running. The caller must hold m.mtx.
func (m *ConsulSyncManager) stopLoop() {
	if m.cancelFunc == nil {
		return
	}
	m.cancelFunc()
	m.cancelFunc = nil
}
// waitForSession returns the current managed session ID. It lazily starts the session loop and blocks
// until a session is available or the wait on sessionCond returns an error (e.g. ctx is cancelled).
func (m *ConsulSyncManager) waitForSession(ctx context.Context) (string, error) {
	m.mtx.Lock()
	defer m.mtx.Unlock()
	//nolint:contextcheck // startLoop starts background goroutine, it needs ApplicationContext to ensure the loop is not cancelled by given ctx
	if e := m.startLoop(); e != nil {
		return "", e
	}
	for m.session == "" {
		if e := m.sessionCond.Wait(ctx); e != nil {
			return "", e
		}
	}
	return m.session, nil
}
// updateSession stores the given session ID and, if it differs from the previous one, wakes up all
// goroutines waiting on sessionCond. The broadcast happens while m.mtx is still held.
func (m *ConsulSyncManager) updateSession(sid string) {
	m.mtx.Lock()
	defer m.mtx.Unlock()
	changed := m.session != sid
	m.session = sid
	if changed {
		m.sessionCond.Broadcast()
	}
}
// sessionLoop is the main loop that manages the shared consul session. Each iteration:
//  1. clears the current session,
//  2. creates a new one, retrying after RetryDelay on failure,
//  3. publishes it via updateSession,
//  4. keeps renewing it until renewal stops, and
//  5. notifies all existing locks through refresh().
// The loop exits when ctx is cancelled.
func (m *ConsulSyncManager) sessionLoop(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			logger.WithContext(ctx).Infof("sync manager stopped")
			return
		default:
		}
		// reset session
		m.updateSession("")
		wOpts := (*api.WriteOptions)(nil).WithContext(ctx)
		session, _, e := m.client.Session().Create(&api.SessionEntry{
			Name:      m.option.Name,
			TTL:       m.option.TTL.String(),
			LockDelay: m.option.LockDelay,
			Behavior:  "delete",
		}, wOpts)
		if e != nil {
			// Surface the failure (e.g. consul unreachable), so an unavailable lock system is visible.
			// Skip the warning when the failure is just caused by our own cancellation.
			if ctx.Err() == nil {
				logger.WithContext(ctx).Warnf("failed to create consul session, retry in %v: %v", m.option.RetryDelay, e)
			}
			select {
			case <-time.After(m.option.RetryDelay):
			case <-ctx.Done():
			}
			continue
		}
		m.updateSession(session)

		// keep renewing the session
		_ = m.keepSession(ctx, session)
		// session is invalid/expired by this point.
		// try to notify all existing locks
		m.mtx.Lock()
		for _, l := range m.locks {
			l.refresh()
		}
		m.mtx.Unlock()
	}
}
// keepSession keeps renewing the given session until ctx is cancelled (returns context.Canceled)
// or renewal fails (returns the renewal error, e.g. api.ErrSessionExpired).
func (m *ConsulSyncManager) keepSession(ctx context.Context, session string) error {
	ttl := m.option.TTL.String()
	for {
		select {
		case <-ctx.Done():
			return context.Canceled
		default:
		}
		// RenewPeriodic invokes Session.Renew periodically until the done channel (ctx.Done()) is closed,
		// and returns early if a renewal fails.
		wOpts := (*api.WriteOptions)(nil).WithContext(ctx)
		renewErr := m.client.Session().RenewPeriodic(ttl, session, wOpts, ctx.Done())
		if renewErr == nil {
			continue
		}
		if errors.Is(renewErr, api.ErrSessionExpired) {
			logger.WithContext(ctx).Warnf("session expired")
		} else {
			logger.WithContext(ctx).Warnf("session lost: %v", renewErr)
		}
		return renewErr
	}
}
package consuldsync
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/consul"
"github.com/cisco-open/go-lanai/pkg/dsync"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
)
var (
	logger = log.New("DSync")

	// Module provides a consul-backed dsync.SyncManager (see provideSyncManager) and pulls in dsync.Module.
	Module = &bootstrap.Module{
		Name:       "distributed",
		Precedence: bootstrap.DistributedLockPrecedence,
		Options: []fx.Option{
			fx.Provide(provideSyncManager),
		},
		Modules: []*bootstrap.Module{dsync.Module},
	}
)

// Use registers Module with the bootstrap framework.
func Use() {
	bootstrap.Register(Module)
}
/**************************
Provider
***************************/
// syncDI holds the dependencies of provideSyncManager.
type syncDI struct {
	fx.In
	AppCtx *bootstrap.ApplicationContext
	// Conn is optional for fx, but provideSyncManager returns an error when it is missing
	Conn *consul.Connection `optional:"true"`
}
// provideSyncManager builds the consul-backed dsync.SyncManager. It fails when no *consul.Connection
// is available in the fx container.
func provideSyncManager(di syncDI) (dsync.SyncManager, error) {
	conn := di.Conn
	if conn != nil {
		return NewConsulLockManager(di.AppCtx, conn), nil
	}
	return nil, fmt.Errorf("*consul.Connection is required for 'dsync' package")
}
/**************************
Initialize
***************************/
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
// Package dsync
// Provides distributed synchronization support of microservices and provide common usage patterns
// around distributed lock, such as lock-based service leader election.
package dsync
import (
"context"
"encoding/json"
"fmt"
)
// FxGroup is a group name for uber.fx.
// A SyncManager provided in this group takes precedence over a plainly provided SyncManager (see initialize).
const FxGroup = `dsync`

// Sentinel errors of this package. Each newError call gets its own type ID, so errors.Is also matches
// copies derived via WithMessage/WithCause.
var (
	// ErrLockUnavailable indicates the lock is held by another session (or could not be kept)
	ErrLockUnavailable = newError("lock is held by another session")
	// ErrUnlockFailed indicates an explicit release of the lock failed
	ErrUnlockFailed = newError("failed to release lock")
	// ErrSessionUnavailable indicates no session could be obtained for the lock
	ErrSessionUnavailable = newError("session is not available")
	// ErrSyncManagerStopped is returned once the SyncManager has been stopped
	ErrSyncManagerStopped = newError("sync manager stopped")
	// ErrFailedInitialization indicates the distributed lock system could not be initialized/started
	ErrFailedInitialization = newError("sync manager failed to start")
)
// SyncManager manages distributed locks across the application.
type SyncManager interface {
	// Lock returns a distributed lock with the given key. If a Lock already exists with the same key,
	// the options are ignored and the same Lock is returned.
	//
	// The returned Lock is goroutine-safe, but locking/releasing the same lock from different goroutines may cause
	// complicated scenarios. It's the application's responsibility to coordinate such concurrent usage.
	Lock(key string, opts ...LockOptions) (Lock, error)
}
// SyncManagerLifecycle is optionally implemented by a SyncManager. When implemented, Start and Stop are
// invoked from the application's fx start/stop hooks (see initialize).
type SyncManagerLifecycle interface {
	Start(ctx context.Context) error
	Stop(ctx context.Context) error
}

// LockOptions customizes a LockOption when creating a Lock.
type LockOptions func(opt *LockOption)

// LockOption holds per-lock settings.
type LockOption struct {
	// Valuer produces the metadata stored alongside the lock
	Valuer LockValuer
}

// LockValuer is used to annotate the lock in an external infra service.
// Its output is treated literally and serves as the lock's metadata.
type LockValuer func() []byte
// Lock is a distributed mutex lock backed by an external infrastructure service such as consul or redis.
// Once lock acquisition is started (Lock.Lock or Lock.TryLock), the Lock keeps trying to acquire or
// re-acquire the lock, regardless of the result. This continues until Lock.Release is manually invoked,
// because the lock might be revoked by an operator or by the external infra service.
//
// Long-running goroutines should monitor the Lost channel after the lock is acquired.
// When the Lost channel is signalled, there is no need to re-invoke Lock.Lock or Lock.TryLock, because the
// internal loop tries to re-acquire the lock. However, any existing tasks relying on this lock should be
// stopped, because there is no guarantee that the lock will be re-acquired.
type Lock interface {
	// Key returns the unique identifier of the lock.
	Key() string
	// Lock attempts to acquire the lock and blocks until the lock is acquired or the context is cancelled/timed out.
	// Invoking Lock after the lock is acquired (or re-acquired after some error) returns immediately.
	//
	// A cancellable context.Context can be used to abort the current attempt, but it won't stop the lock from
	// trying in the background.
	//
	// It is NOT safe to assume that the lock is guaranteed to be held until Release(). The lock might be lost
	// due to session invalidation, communication errors, operator intervention, etc.
	//
	// Lost() returns a channel that is closed if our lock is lost or an error occurred.
	// By default, dsync implementations prefer liveness over safety, and an application must be able to handle
	// the lock being lost.
	//
	// Important: Regardless of the result, the lock keeps trying to acquire the lock in the background.
	// So a pairing call of Release() is always required after the lock is no longer needed, even if the context is canceled.
	Lock(ctx context.Context) error
	// TryLock differs from Lock in the following ways:
	// - TryLock stops blocking when the lock is held by another instance/session
	// - TryLock stops blocking when an unrecoverable error happens during lock acquisition
	// Note: TryLock may temporarily block when connectivity to the external infra service is not available
	//
	// Important: Regardless of the result, the lock keeps trying to acquire the lock in the background.
	// So a pairing call of Release() is always required after the lock is no longer needed, even if the context is canceled.
	TryLock(ctx context.Context) error
	// Release stops the attempt to acquire the lock and releases the lock if it's already held.
	// Release must be called every time after Lock or TryLock is called, unless the application intends
	// to hold the lock indefinitely.
	//
	// Invoking Release multiple times has no effect.
	//
	// Note: the Lost channel stops signalling after Release, until Lock or TryLock is called again.
	Release() error
	// Lost returns a channel that signals long-running goroutines when the lock is lost. The cause can be a
	// network error, operator intervention, a manual Release() call from another goroutine, etc.
	//
	// When the Lost channel is signalled, there is no need to re-invoke Lock.Lock or Lock.TryLock for lock
	// re-acquisition unless it was caused by a manual Release() call, but all relying tasks should pause.
	Lost() <-chan struct{}
}
/*********************
Common Impl
*********************/
// LockWithKey returns a distributed Lock with the given key. If a Lock already exists with the same key,
// the options are ignored and the same Lock is returned.
//
// The returned Lock is goroutine-safe, but locking/releasing the same lock from different goroutines may cause
// complicated scenarios. It's the application's responsibility to coordinate such concurrent usage.
//
// This function panics if the internal SyncManager is not initialized yet, or if the SyncManager
// fails to create the lock (e.g. the key is not provided).
func LockWithKey(key string, opts ...LockOptions) Lock {
	manager := syncManager
	if manager == nil {
		panic("SyncManager is not initialized")
	}
	lock, err := manager.Lock(key, opts...)
	if err != nil {
		panic(err)
	}
	return lock
}
// NewJsonLockValuer is the default implementation of LockValuer.
// The returned valuer marshals v to JSON on every call. If marshalling fails, it returns a JSON string
// describing the failure instead, so the lock metadata is always valid JSON.
func NewJsonLockValuer(v interface{}) LockValuer {
	return func() []byte {
		data, e := json.Marshal(v)
		if e == nil {
			return data
		}
		// Encode the fallback with json.Marshal rather than wrapping it in literal quotes. That way
		// quotes/backslashes inside the error text are escaped and the result stays valid JSON.
		// Marshalling a plain string does not fail, so the error can be ignored.
		fallback, _ := json.Marshal(fmt.Sprintf("marshalling error: %v", e))
		return fallback
	}
}
package dsync
import (
"errors"
"fmt"
)
// errorTypeCounter hands out a unique type ID to every error created by newError.
// It is not synchronized; newError is expected to be called from package-level var initialization only.
var errorTypeCounter int

// comparableError is an error identified by a type ID rather than by its message.
// Values derived through WithMessage/WithCause keep the ID, so errors.Is matches them against the
// sentinel they were derived from.
type comparableError struct {
	typ   int
	msg   string
	cause error
}

// newError creates a sentinel comparableError with a new type ID and a formatted message.
func newError(tmpl string, args ...interface{}) comparableError {
	errorTypeCounter++
	var created comparableError
	created.typ = errorTypeCounter
	created.msg = fmt.Sprintf(tmpl, args...)
	return created
}

// Error returns the error message.
func (ce comparableError) Error() string {
	return ce.msg
}

// Is reports whether target is (or wraps) a comparableError of the same type ID.
// For other targets, it defers to the cause.
func (ce comparableError) Is(target error) bool {
	var other comparableError
	if !errors.As(target, &other) {
		return errors.Is(ce.cause, target)
	}
	return ce.typ == other.typ
}

// Unwrap returns the cause, if any.
func (ce comparableError) Unwrap() error {
	return ce.cause
}

// WithMessage returns a copy with the same type ID and cause, but a new formatted message.
func (ce comparableError) WithMessage(tmpl string, args ...interface{}) comparableError {
	derived := ce
	derived.msg = fmt.Sprintf(tmpl, args...)
	return derived
}

// WithCause returns a copy with the same type ID, the given cause, and the cause appended to the message.
func (ce comparableError) WithCause(err error) comparableError {
	derived := ce
	derived.msg = fmt.Sprintf(`%s: %v`, ce.msg, err)
	derived.cause = err
	return derived
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package dsync
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"sync"
)
const (
	// leadershipLockKeyFormat builds the leadership lock key from the application name
	leadershipLockKeyFormat = "service/%s/leadership"
)

var (
	// leadershipOnce ensures the leadership lock is created and started only once
	leadershipOnce sync.Once
	// leadershipLock is the global leadership lock, nil until startLeadershipLock succeeds
	leadershipLock Lock
)
// LeadershipLock returns the globally maintained lock for leadership election.
// To check leadership, use Lock.TryLock and check the error.
// Example:
//
//	if e := LeadershipLock().TryLock(ctx); e == nil {
//		// do what a leader should do
//	}
//
// This function panics if it's called too early during startup.
// Note: the Lock.Lost() channel should be monitored by long-running goroutines, since leadership could be
// revoked by operators at any time.
func LeadershipLock() Lock {
	if lock := leadershipLock; lock != nil {
		return lock
	}
	panic("Leadership Lock is not initialized yet")
}
// startLeadershipLock creates the global leadership lock (once) and starts a background goroutine.
// The goroutine keeps the lock trying and logs leadership changes until the application context is done.
// The lock uses di.AppCtx rather than the start-hook context, so it lives beyond application start.
func startLeadershipLock(_ context.Context, di initDI) (err error) {
	appCtx := di.AppCtx
	//nolint:contextcheck // we want di.AppCtx on purpose
	leadershipOnce.Do(func() {
		lockKey := fmt.Sprintf(leadershipLockKeyFormat, appCtx.Name())
		leadershipLock, err = syncManager.Lock(lockKey, func(opt *LockOption) {
			opt.Valuer = leaderLockValuer(appCtx)
		})
		if err != nil {
			return
		}
		// We don't care about the lock result, as long as we tell the lock to keep trying.
		// This goroutine is for logging purposes only.
		go func() {
			for {
				if leadershipLock.Lock(appCtx) == nil {
					logger.WithContext(appCtx).Infof("Leadership - become leader [%s]", leadershipLock.Key())
					select {
					case <-leadershipLock.Lost():
						logger.WithContext(appCtx).Infof("Leadership - lost [%s]", leadershipLock.Key())
					case <-appCtx.Done():
					}
				}
				select {
				case <-appCtx.Done():
					return
				default:
				}
			}
		}()
	})
	return
}
// leaderLockValuer builds the leadership lock metadata: service name, port and context path (from the
// application config), plus build info.
func leaderLockValuer(appCtx *bootstrap.ApplicationContext) LockValuer {
	service := map[string]interface{}{
		"name":         appCtx.Name(),
		"port":         appCtx.Config().Value("server.port"),
		"context-path": appCtx.Config().Value("server.context-path"),
	}
	metadata := map[string]interface{}{
		"service": service,
		"build":   bootstrap.BuildInfoMap,
	}
	return NewJsonLockValuer(metadata)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package dsync
import (
"context"
"embed"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
)
// defaultConfigFS holds the embedded default configuration, registered by Module via appconfig.FxEmbeddedDefaults.
//
//go:embed defaults-dsync.yml
var defaultConfigFS embed.FS

var logger = log.New("DSync")

// syncManager is the global SyncManager, selected in initialize
var syncManager SyncManager

// Module wires the embedded defaults and the initialize invocation into the application.
var Module = &bootstrap.Module{
	Name:       "distributed",
	Precedence: bootstrap.DistributedLockPrecedence,
	Options: []fx.Option{
		appconfig.FxEmbeddedDefaults(defaultConfigFS),
		fx.Invoke(initialize),
	},
}
/**************************
Provider
***************************/
/**************************
Initialize
***************************/
// initDI holds the dependencies of initialize.
type initDI struct {
	fx.In
	Lifecycle fx.Lifecycle
	AppCtx    *bootstrap.ApplicationContext
	// Manager is the default SyncManager, if any was provided
	Manager SyncManager `optional:"true"`
	// ManagerOverrides are SyncManagers provided in the "dsync" group (FxGroup); the first one wins over Manager
	ManagerOverrides []SyncManager `group:"dsync"`
}
// initialize selects the global SyncManager (the first group override wins, otherwise the plain Manager)
// and registers lifecycle hooks:
//   - OnStart: starts the manager if it implements SyncManagerLifecycle, then starts the leadership lock.
//   - OnStop: stops the manager if it implements SyncManagerLifecycle.
// It fails with ErrFailedInitialization when no SyncManager is available.
func initialize(di initDI) error {
	if len(di.ManagerOverrides) != 0 {
		syncManager = di.ManagerOverrides[0]
	} else {
		syncManager = di.Manager
	}
	if syncManager == nil {
		return ErrFailedInitialization.WithMessage(`unable to initialize distributed lock system and leadership lock. ` +
			`Hint: provide a dsync.SyncManager with 'consuldsync.Use()' or 'redisdsync.Use()' or with your own implementation `)
	}

	lifecycle, hasLifecycle := syncManager.(SyncManagerLifecycle)
	di.Lifecycle.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			if hasLifecycle {
				if e := lifecycle.Start(ctx); e != nil {
					return ErrFailedInitialization.WithCause(e)
				}
			}
			// start leader election lock
			return startLeadershipLock(ctx, di)
		},
		OnStop: func(ctx context.Context) error {
			if !hasLifecycle {
				return nil
			}
			return lifecycle.Stop(ctx)
		},
	})
	return nil
}
package redisdsync
import (
"context"
"encoding/json"
"errors"
"github.com/cisco-open/go-lanai/pkg/dsync"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/xsync"
"github.com/go-redsync/redsync/v4"
"sync"
"time"
)
// lockState is the state of a RedisLock as tracked by its lock loop.
type lockState int

const (
	// stateUnknown: the loop is not running, or the lock has not been acquired yet
	stateUnknown lockState = iota
	// stateAcquired: the lock is currently held
	stateAcquired
	// stateError: the last acquisition/maintenance attempt failed (see RedisLock.lastErr)
	stateError
)
// RedisLockOptions customizes a RedisLockOption when creating a RedisLock.
type RedisLockOptions func(opt *RedisLockOption)

// RedisLockOption configures a RedisLock.
type RedisLockOption struct {
	// Context is the parent context of the lock loop
	Context context.Context
	// Name is the redis mutex name, also returned by RedisLock.Key
	Name string
	// Valuer produces the lock metadata stored in the mutex value
	Valuer dsync.LockValuer
	// AutoExpiry is how long until the acquired lock expires (is released) in case the application crashes.
	// It's recommended to keep this value larger than 5 seconds.
	// Default is 10 seconds.
	AutoExpiry time.Duration
	// RetryDelay is how long we wait after a retryable error (usually a network error).
	// Default is 500 milliseconds.
	RetryDelay time.Duration
	// TimeoutFactor is used to calculate the redis CMD timeout when acquiring, extending and releasing the lock:
	// timeout = AutoExpiry * TimeoutFactor
	// Note: the value should be smaller than 0.5 and is recommended to be between 0.01 and 0.1, depending on AutoExpiry.
	// Default is 0.05.
	TimeoutFactor float64
	// MaxExtendRetries is how many times we attempt to extend the lock before giving up.
	// Default is 3.
	MaxExtendRetries int
}
// RedisLock implements dsync.Lock on top of a redsync mutex. A background loop acquires the mutex,
// keeps extending it, and re-acquires it when it's lost.
type RedisLock struct {
	mtx     sync.Mutex
	rsMutex *redsync.Mutex
	option  RedisLockOption
	// State Variables, requires mutex lock to read and write
	// loopContext/loopCancelFunc belong to the running lock loop; both nil when the loop is stopped
	loopContext    context.Context
	loopCancelFunc context.CancelFunc
	// lockLostCh is closed when the lock leaves stateAcquired
	lockLostCh chan struct{}
	state      lockState
	// stateCond is bound to mtx and broadcast on state changes
	stateCond *xsync.Cond
	// lastErr is the error reported by TryLock in stateError
	lastErr error
}
// newRedisLock creates a RedisLock with the following defaults, overridable with opts:
//   - AutoExpiry = 10s
//   - RetryDelay = 500ms
//   - MaxExtendRetries = 3
//   - TimeoutFactor = 0.05
// The lock loop is not started until Lock/TryLock is called.
func newRedisLock(rs *redsync.Redsync, opts ...RedisLockOptions) (lock *RedisLock) {
	option := RedisLockOption{
		Valuer: dsync.NewJsonLockValuer(map[string]string{
			"name": "redis distributed lock",
		}),
		AutoExpiry:       10 * time.Second,
		RetryDelay:       500 * time.Millisecond,
		MaxExtendRetries: 3,
		TimeoutFactor:    0.05,
	}
	for _, apply := range opts {
		apply(&option)
	}
	// Note: we only use TryLock and perform indefinite retries, so WithTries is set to 1 in order to get
	// a proper error. See redsync.Mutex.TryLockContext for details.
	rsMutex := rs.NewMutex(option.Name,
		redsync.WithExpiry(option.AutoExpiry),
		redsync.WithTries(1),
		redsync.WithTimeoutFactor(option.TimeoutFactor),
		redsync.WithShufflePools(true),
		redsync.WithGenValueFunc(genValueFunc(option.Valuer)),
	)

	lock = &RedisLock{
		rsMutex: rsMutex,
		option:  option,
	}
	lock.stateCond = xsync.NewCond(&lock.mtx)
	// we start with a closed lost channel
	lock.lockLostCh = make(chan struct{}, 1)
	close(lock.lockLostCh)
	return lock
}
// Key returns the lock's unique identifier, i.e. the name of the underlying redsync mutex.
func (l *RedisLock) Key() string {
	mutex := l.rsMutex
	return mutex.Name()
}
// Lock starts the lock loop if needed and blocks until the lock is acquired, the loop is stopped
// (returns context.Canceled), or ctx is done. See dsync.Lock.Lock.
func (l *RedisLock) Lock(ctx context.Context) error {
	l.lazyStart()
	acquiredOrStopped := func(state lockState) (bool, error) {
		if state == stateAcquired {
			return true, nil
		}
		if l.loopContext == nil {
			return true, context.Canceled
		}
		return false, nil
	}
	return l.waitForState(ctx, acquiredOrStopped)
}
// TryLock is like Lock, but it also returns (with lastErr) as soon as the lock loop reports an error
// state. See dsync.Lock.TryLock.
func (l *RedisLock) TryLock(ctx context.Context) error {
	l.lazyStart()
	settled := func(state lockState) (bool, error) {
		if state == stateAcquired {
			return true, nil
		}
		if state == stateError {
			return true, l.lastErr
		}
		if l.loopContext == nil {
			return true, context.Canceled
		}
		return false, nil
	}
	return l.waitForState(ctx, settled)
}
// Release stops the lock loop; the loop's cleanup releases the redis mutex. See dsync.Lock.Release.
func (l *RedisLock) Release() error {
	err := l.release()
	return err
}

// Lost returns the channel that is closed when the currently held lock is lost. See dsync.Lock.Lost.
func (l *RedisLock) Lost() <-chan struct{} {
	l.mtx.Lock()
	ch := l.lockLostCh
	l.mtx.Unlock()
	return ch
}
// lazyStart starts the lock loop unless it is already running.
func (l *RedisLock) lazyStart() {
	l.mtx.Lock()
	defer l.mtx.Unlock()
	if l.loopContext != nil {
		// already maintaining the lock loop
		return
	}
	l.startLoop()
}
// waitForState blocks (holding l.mtx between checks) until stateMatcher reports done for the current state.
// It then returns the matcher's error. It returns early with the context error when waiting on
// stateCond is cancelled or times out.
func (l *RedisLock) waitForState(ctx context.Context, stateMatcher func(lockState) (bool, error)) error {
	l.mtx.Lock()
	defer l.mtx.Unlock()
	for {
		done, result := stateMatcher(l.state)
		if done {
			return result
		}
		waitErr := l.stateCond.Wait(ctx)
		if errors.Is(waitErr, context.Canceled) || errors.Is(waitErr, context.DeadlineExceeded) {
			return waitErr
		}
	}
}
// updateState atomically updates the state, executes additional setters and broadcasts the change.
// If the given state is negative, only the setters are executed.
//
// Side effects on state transitions:
//   - entering stateAcquired creates a fresh lockLostCh for Lost() callers
//   - leaving stateAcquired closes lockLostCh, signalling that the lock was lost
//   - waiters on stateCond are woken on every actual state change and on every stateError update
//     (presumably so that TryLock waiters observe a new lastErr even when the state stays stateError)
func (l *RedisLock) updateState(s lockState, setters ...func()) {
	l.mtx.Lock()
	defer l.mtx.Unlock()
	for _, fn := range setters {
		fn()
	}
	if s < 0 {
		return
	}
	if s == stateAcquired && l.state != s {
		l.lockLostCh = make(chan struct{}, 1)
	} else if l.state == stateAcquired && l.state != s {
		close(l.lockLostCh)
	}
	// deferred calls run LIFO: Broadcast runs before the deferred Unlock, i.e. while l.mtx is still held
	if s == stateError || l.state != s {
		defer l.stateCond.Broadcast()
	}
	l.state = s
}
// startLoop kicks off the background lock loop using a new cancellable context derived from
// option.Context. The caller must hold l.mtx.
func (l *RedisLock) startLoop() {
	loopCtx, cancel := context.WithCancel(l.option.Context)
	l.loopContext, l.loopCancelFunc = loopCtx, cancel
	go l.lockLoop(loopCtx, cancel)
}
// stopLoop cancels the lock loop (if running) and clears the stored loop context.
// The caller must hold l.mtx.
func (l *RedisLock) stopLoop() {
	if cancel := l.loopCancelFunc; cancel != nil {
		cancel()
	}
	l.loopContext, l.loopCancelFunc = nil, nil
}
// lockLoop is the main loop that attempts to maintain the lock.
// While running, the lock state alternates between Acquired and Error:
//   - it tries to acquire the redis mutex, retrying after RetryDelay on failure;
//   - it then keeps the mutex alive via monitorLock;
//   - when the lock is lost, it starts over.
// The loop exits when the given context is cancelled (e.g. the lock is released). On exit it:
//   - explicitly unlocks the redis mutex, in case it's still held,
//   - resets the state to Unknown, and
//   - cancels the context via cancelFunc.
func (l *RedisLock) lockLoop(ctx context.Context, cancelFunc context.CancelFunc) {
	defer cancelFunc()
	defer func() {
		// Best-effort cleanup: the lock may still be held, e.g. when the context is canceled after
		// acquisition. The unlock result is irrelevant because we are leaving anyway.
		_, _ = l.rsMutex.Unlock()
		l.updateState(stateUnknown)
	}()
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		if e := l.rsMutex.TryLockContext(ctx); e != nil {
			if !errors.Is(e, context.Canceled) && !errors.Is(e, context.DeadlineExceeded) {
				l.updateState(stateError, func() {
					l.lastErr = dsync.ErrLockUnavailable.WithMessage(`lock [%s] is held by another session`, l.option.Name).WithCause(e)
				})
				l.delay(ctx, l.option.RetryDelay)
			}
			// cancelled attempts and failures both start over from the top of the loop
			continue
		}
		logger.WithContext(ctx).Debugf("acquired lock [%s]", l.option.Name)
		l.updateState(stateAcquired, func() { l.lastErr = nil })

		// lock acquired: keep extending it until it's lost or we are cancelled
		lostErr := l.monitorLock(ctx)
		if errors.Is(lostErr, context.Canceled) || errors.Is(lostErr, context.DeadlineExceeded) {
			continue
		}
		logger.WithContext(ctx).Debugf("lost lock [%s] - %v", l.option.Name, lostErr)
		l.updateState(stateError, func() { l.lastErr = lostErr })
	}
}
// monitorLock is a long-running routine that keeps extending the acquired lock. Each round waits for
// half of the remaining expiry and then extends the lock.
// It switches to "wait for expiry" mode when either:
//   - there is not enough time left for another attempt (remaining/2 < command timeout), or
//   - MaxExtendRetries consecutive extensions have failed.
// In that mode it waits until the lock expires and reports it as lost.
//
// Returns context.Canceled when ctx is cancelled. Otherwise returns a non-nil error describing why the
// lock is lost (the last extension error, if any).
func (l *RedisLock) monitorLock(ctx context.Context) error {
	var extendErr error
	var failedAttempts int
	timeout := time.Duration(float64(l.option.AutoExpiry) * l.option.TimeoutFactor)
	for {
		expire := time.Until(l.rsMutex.Until())
		wait := expire / 2
		// Check if we have enough time to extend it. If not, we enter "wait for expiry" mode
		waitForExpiry := wait < timeout || failedAttempts >= l.option.MaxExtendRetries
		if waitForExpiry {
			wait = expire
		}
		if !l.delay(ctx, wait) {
			return context.Canceled
		}
		if waitForExpiry {
			// The lock has expired by now. Report it as an error (never context.Canceled); otherwise the
			// caller would keep the Acquired state and never signal Lost().
			if extendErr == nil {
				extendErr = dsync.ErrLockUnavailable.WithMessage(`lock [%s] expired before it could be extended`, l.option.Name)
			}
			return extendErr
		}
		// regardless the result, lock is not lost yet, if we cannot extend it now, we will try it later
		ok, e := l.rsMutex.ExtendContext(ctx)
		switch {
		case e == nil && ok:
			// Extended: the lock is fresh again, so only consecutive failures count toward MaxExtendRetries
			failedAttempts = 0
			extendErr = nil
			continue
		case e == nil:
			extendErr = dsync.ErrLockUnavailable.WithMessage(`failed to extend lock with unknown reason`)
		default:
			extendErr = e
		}
		failedAttempts++
		// log the recorded error; e itself may be nil here
		logger.WithContext(ctx).Debugf("%v", extendErr)
	}
}
// delay blocks for the given duration or until ctx is done, whichever comes first.
// It returns true if the full delay elapsed and false if ctx was done first.
// A non-positive delay returns (almost) immediately.
// The timer is stopped on return, so cancellation does not leave a pending timer behind.
func (l *RedisLock) delay(ctx context.Context, delay time.Duration) (success bool) {
	timer := time.NewTimer(delay)
	defer timer.Stop()
	select {
	case <-timer.C:
		return true
	case <-ctx.Done():
		return false
	}
}
// release stops the lock loop, if one is running. The actual unlocking is performed by the loop's own
// cleanup once it stops. l.mtx is held for the whole operation. It currently always returns nil.
func (l *RedisLock) release() error {
	l.mtx.Lock()
	defer l.mtx.Unlock()
	if l.loopContext != nil {
		l.stopLoop()
	}
	return nil
}
/***********************
helpers
***********************/
// lockValue is the JSON document stored as the lock's value (see genValueFunc).
type lockValue struct {
	// Metadata is the valuer-provided metadata: decoded JSON when possible, otherwise the raw bytes.
	Metadata interface{} `json:"metadata"`
	// Token is a random string (utils.RandomString(16)) that makes the value unique per generator.
	Token string `json:"token"`
}
// genValueFunc builds a value generator for the redis lock.
//
// The generated value is the JSON encoding of lockValue:
//   - metadata comes from valuer (if any). It is decoded as JSON when it parses; otherwise the raw
//     bytes are kept, which encoding/json emits as a base64 string. A nil valuer yields null.
//   - token is utils.RandomString(16).
//
// The value is computed once, so every call of the returned function yields the same string and a nil
// error. If marshalling fails, the bare token is used as the value.
func genValueFunc(valuer dsync.LockValuer) func() (string, error) {
	var raw []byte
	if valuer != nil {
		raw = valuer()
	}

	var meta interface{}
	if err := json.Unmarshal(raw, &meta); err != nil {
		// not JSON: keep the raw bytes as-is
		meta = raw
	}

	payload := lockValue{Metadata: meta, Token: utils.RandomString(16)}
	encoded, err := json.Marshal(payload)
	if err != nil {
		encoded = []byte(payload.Token)
	}

	value := string(encoded)
	return func() (string, error) {
		return value, nil
	}
}
package redisdsync
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/dsync"
redislib "github.com/go-redis/redis/v8"
"github.com/go-redsync/redsync/v4"
redsyncredis "github.com/go-redsync/redsync/v4/redis"
"github.com/go-redsync/redsync/v4/redis/goredis/v8"
"sync"
"time"
)
// RedisSyncOptions customizes a RedisSyncOption; applied in order by NewRedisSyncManager.
type RedisSyncOptions func(opt *RedisSyncOption)

// RedisSyncOption configures a RedisSyncManager and the locks it creates.
type RedisSyncOption struct {
	// Clients are go-redis/v8 clients.
	// Each client should be able to connect to an independent Redis master/cluster/sentinel-master to form quorum
	Clients []redislib.UniversalClient
	// TTL see RedisLockOption.AutoExpiry
	TTL time.Duration
	// RetryDelay see RedisLockOption.RetryDelay
	RetryDelay time.Duration
	// TimeoutFactor see RedisLockOption.TimeoutFactor
	TimeoutFactor float64
}
// NewRedisSyncManager creates a RedisSyncManager backed by redsync.
//
// Defaults (TTL 10s, RetryDelay 1s, TimeoutFactor 0.05) are set first and may be overridden by opts,
// which are applied in order. One redsync pool is created per configured client.
func NewRedisSyncManager(appCtx *bootstrap.ApplicationContext, opts ...RedisSyncOptions) *RedisSyncManager {
	opt := RedisSyncOption{
		TTL:           10 * time.Second,
		RetryDelay:    time.Second,
		TimeoutFactor: 0.05,
	}
	for _, apply := range opts {
		apply(&opt)
	}

	pools := make([]redsyncredis.Pool, 0, len(opt.Clients))
	for _, c := range opt.Clients {
		pools = append(pools, goredis.NewPool(c))
	}

	return &RedisSyncManager{
		appCtx:  appCtx,
		options: opt,
		syncer:  redsync.New(pools...),
		locks:   make(map[string]*RedisLock),
	}
}
// RedisSyncManager creates and tracks redsync-backed distributed locks, one per key.
type RedisSyncManager struct {
	appCtx  *bootstrap.ApplicationContext
	options RedisSyncOption
	// mtx guards locks
	mtx    sync.Mutex
	syncer *redsync.Redsync
	// locks caches created locks by key; Lock returns the cached instance for repeated keys
	locks map[string]*RedisLock
}
// Lock returns the distributed lock for key, creating it on first use.
// Subsequent calls with the same key return the same instance. An empty key is rejected.
// The new lock's name is key; TTL, RetryDelay and TimeoutFactor come from the manager's options.
func (m *RedisSyncManager) Lock(key string, opts ...dsync.LockOptions) (dsync.Lock, error) {
	if len(key) == 0 {
		return nil, fmt.Errorf(`cannot create distributed lock: key is required but missing`)
	}

	lockOpt := dsync.LockOption{
		Valuer: dsync.NewJsonLockValuer(map[string]string{
			"name": fmt.Sprintf("distributed lock - %s", m.appCtx.Name()),
		}),
	}
	for _, apply := range opts {
		apply(&lockOpt)
	}
	// NOTE(review): lockOpt (including any caller-provided Valuer) is never passed to newRedisLock
	// below, so LockOptions currently have no effect — confirm whether RedisLockOption should carry it.

	m.mtx.Lock()
	defer m.mtx.Unlock()
	lock, exists := m.locks[key]
	if !exists {
		lock = newRedisLock(m.syncer, func(o *RedisLockOption) {
			o.Context = m.appCtx
			o.Name = key
			o.AutoExpiry = m.options.TTL
			o.RetryDelay = m.options.RetryDelay
			o.TimeoutFactor = m.options.TimeoutFactor
		})
		m.locks[key] = lock
	}
	return lock, nil
}
// Start is a lifecycle hook; locks are created lazily by Lock, so there is nothing to start.
func (m *RedisSyncManager) Start(context.Context) error {
	return nil
}
// Stop releases every lock created by this manager.
//
// It attempts every lock even if some fail. If any release fails, it returns an ErrUnlockFailed error
// listing the failed keys, with the first encountered release error attached as its cause so the root
// failure is not lost.
// Note: failed keys are listed in map iteration order, which is not deterministic.
func (m *RedisSyncManager) Stop(_ context.Context) error {
	m.mtx.Lock()
	defer m.mtx.Unlock()
	var failed []string
	var cause error
	for k, lock := range m.locks {
		if e := lock.Release(); e != nil {
			failed = append(failed, k)
			if cause == nil {
				cause = e
			}
		}
	}
	if len(failed) > 0 {
		return dsync.ErrUnlockFailed.WithMessage(`unable to release locks %v`, failed).WithCause(cause)
	}
	return nil
}
package redisdsync
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/dsync"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/redis"
redislib "github.com/go-redis/redis/v8"
"go.uber.org/fx"
)
var (
	// logger is the package-wide logger, named "DSync".
	logger = log.New("DSync")

	// Module is the bootstrap module that provides a Redis-backed dsync.SyncManager.
	// It depends on dsync.Module.
	Module = &bootstrap.Module{
		Name:       "distributed",
		Precedence: bootstrap.DistributedLockPrecedence,
		Options:    []fx.Option{fx.Provide(provideSyncManager)},
		Modules:    []*bootstrap.Module{dsync.Module},
	}
)

// Use registers Module with the bootstrap.
func Use() {
	bootstrap.Register(Module)
}
/**************************
Provider
***************************/
// syncDI collects the dependencies of provideSyncManager.
type syncDI struct {
	fx.In
	AppCtx *bootstrap.ApplicationContext
	// RedisFactory is used only when no clients are provided in the "dsync" group
	RedisFactory redis.ClientFactory `optional:"true"`
	// RedisClients are explicitly provided clients (FX group "dsync"); they take precedence over RedisFactory
	RedisClients []redislib.UniversalClient `group:"dsync"`
}
// provideSyncManager builds the Redis-backed dsync.SyncManager.
//
// Clients are chosen as follows:
//  1. clients provided in the "dsync" FX group, if any;
//  2. otherwise a single client created by redis.ClientFactory using DB index 1;
//  3. otherwise an error is returned.
func provideSyncManager(di syncDI) (dsync.SyncManager, error) {
	var clients []redislib.UniversalClient
	if len(di.RedisClients) > 0 {
		clients = append(clients, di.RedisClients...)
	} else if di.RedisFactory != nil {
		client, err := di.RedisFactory.New(di.AppCtx, func(cOpt *redis.ClientOption) {
			cOpt.DbIndex = 1
		})
		if err != nil {
			return nil, dsync.ErrSyncManagerStopped.WithMessage("unable to initialize Redis SyncManager").WithCause(err)
		}
		clients = []redislib.UniversalClient{client}
	} else {
		return nil, fmt.Errorf(`redis.ClientFactory or []go-redis/redis/v8.UniversalClient with FX group '%s' are required for 'redisdsync' package`, dsync.FxGroup)
	}

	manager := NewRedisSyncManager(di.AppCtx, func(opt *RedisSyncOption) {
		opt.Clients = clients
	})
	return manager, nil
}
package httpclient
import (
"fmt"
"sync/atomic"
)
// balancer selects one value out of a list of candidates.
type balancer[T any] interface {
	Balance([]T) (T, error)
}

// newRoundRobinBalancer returns a balancer that cycles through the candidates in order.
func newRoundRobinBalancer[T any]() balancer[T] {
	return new(roundRobin[T])
}

// roundRobin is a balancer whose position is kept in an atomically updated counter,
// so it can be shared between goroutines.
type roundRobin[T any] struct {
	c uint64
}

// Balance returns values[k % len(values)], where k is the number of previous successful selections.
// An empty list yields a "no endpoint found" error and does not advance the counter.
func (rr *roundRobin[T]) Balance(values []T) (v T, err error) {
	n := uint64(len(values))
	if n == 0 {
		return v, NewNoEndpointFoundError(fmt.Errorf("cannot find service"))
	}
	// AddUint64 returns the incremented counter; subtract one so the very first call picks index 0
	ticket := atomic.AddUint64(&rr.c, 1) - 1
	return values[ticket%n], nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package httpclient
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"net/url"
"time"
)
// httpInstanceMatcher matches service instances to be reached over plain HTTP:
// tag secure=false, tag insecure=true, or metadata scheme=http.
var httpInstanceMatcher = discovery.InstanceWithTagKV("secure", "false", true).
	Or(discovery.InstanceWithTagKV("insecure", "true", true)).
	Or(discovery.InstanceWithMetaKV("scheme", "http"))

// httpsInstanceMatcher matches service instances to be reached over HTTPS:
// tag secure=true, tag insecure=false, or metadata scheme=https.
var httpsInstanceMatcher = discovery.InstanceWithTagKV("secure", "true", true).
	Or(discovery.InstanceWithTagKV("insecure", "false", true)).
	Or(discovery.InstanceWithMetaKV("scheme", "https"))

// supportedSchemes are the schemes the executor accepts as an absolute request URL
// (such URLs bypass the target resolver).
var supportedSchemes = utils.NewStringSet("http", "https")
// clientDefaults holds settings given at construction time (NewClient). They are kept separately so
// that updateConfig can combine them with the hooks of each new ClientConfig.
type clientDefaults struct {
	selector discovery.InstanceMatcher
	before   []BeforeHook
	after    []AfterHook
}

// client is the default Client implementation. Copies made by WithService/WithBaseUrl/WithConfig
// share clientDefaults and sdClient (shallow copy).
type client struct {
	defaults *clientDefaults
	config   *ClientConfig
	sdClient discovery.Client
	// before/after are the effective, sorted hooks: defaults + config hooks (rebuilt by updateConfig)
	before []BeforeHook
	after  []AfterHook
	// resolver turns a relative request path into a target URL; nil until WithService/WithBaseUrl is used
	resolver TargetResolver
}
// NewClient creates a Client from DefaultConfig() customized by opts.
//
// By default the client selects healthy instances and logs requests/responses through the configured
// logger. The returned client has no target resolver; use WithService or WithBaseUrl to obtain one that
// can resolve relative request paths.
func NewClient(opts ...ClientOptions) Client {
	defaults := DefaultConfig()
	opt := ClientOption{
		ClientConfig:       *defaults,
		DefaultSelector:    discovery.InstanceIsHealthy(),
		DefaultBeforeHooks: []BeforeHook{HookRequestLogger(defaults)},
		DefaultAfterHooks:  []AfterHook{HookResponseLogger(defaults)},
	}
	for _, apply := range opts {
		apply(&opt)
	}

	c := &client{
		sdClient: opt.SDClient,
		defaults: &clientDefaults{
			selector: opt.DefaultSelector,
			before:   opt.DefaultBeforeHooks,
			after:    opt.DefaultAfterHooks,
		},
	}
	// updateConfig sets c.config and builds the effective hook lists
	c.updateConfig(&opt.ClientConfig)
	return c
}
// WithService returns a copy of the client that resolves targets through service discovery for the
// given service. The default selector is applied first with InvalidateOnError enabled, and opts may
// override either. The copy uses the service defaults (token passthrough hook).
func (c *client) WithService(service string, opts ...SDOptions) (Client, error) {
	if c.sdClient == nil {
		return nil, NewNoEndpointFoundError("cannot create client with service name: service discovery client is not configured")
	}
	instancer, e := c.sdClient.Instancer(service)
	if e != nil {
		return nil, NewNoEndpointFoundError(fmt.Errorf("cannot create client with service name: %s", service), e)
	}

	sdOpts := make([]SDOptions, 0, len(opts)+1)
	sdOpts = append(sdOpts, func(o *SDOption) {
		o.Selector = c.defaults.selector
		o.InvalidateOnError = true
	})
	sdOpts = append(sdOpts, opts...)

	sdResolver, e := NewSDTargetResolver(instancer, sdOpts...)
	if e != nil {
		return nil, NewNoEndpointFoundError(fmt.Errorf("cannot create client with service name: %s", service), e)
	}

	cp := c.shallowCopy()
	cp.resolver = sdResolver
	return cp.WithConfig(defaultServiceConfig()), nil
}
// WithBaseUrl returns a copy of the client that resolves relative request paths against baseUrl,
// using the external-host defaults.
func (c *client) WithBaseUrl(baseUrl string) (Client, error) {
	static, err := NewStaticTargetResolver(baseUrl)
	if err != nil {
		return nil, NewNoEndpointFoundError(fmt.Errorf("cannot create client with base URL: %s", baseUrl), err)
	}
	cp := c.shallowCopy()
	cp.resolver = static
	return cp.WithConfig(defaultExtHostConfig()), nil
}
// WithConfig returns a shallow copy of the client using the given config.
//
// Zero-valued fields of config are filled from the current client's config (see mergeConfig). The
// given config is copied first, so the caller's struct is neither modified nor retained by the returned
// client (reusing it for another client is safe). A nil config keeps the current settings.
func (c *client) WithConfig(config *ClientConfig) Client {
	var merged ClientConfig
	if config != nil {
		merged = *config
	}
	mergeConfig(&merged, c.config)
	cp := c.shallowCopy()
	cp.updateConfig(&merged)
	return cp
}
// Execute sends the request and decodes the response according to opts.
//
// The request is retried according to c.config.RetryCallback, or the default retryCallback when unset,
// within c.config.Timeout. Any execution error is translated into an *Error. The decoded value must be
// a Response or *Response; otherwise an internal error is returned (unless an error already occurred).
func (c *client) Execute(ctx context.Context, request *Request, opts ...ResponseOptions) (ret *Response, err error) {
	var opt responseOption
	for _, apply := range opts {
		apply(&opt)
	}
	fallbackResponseOptions(&opt)

	cb := c.config.RetryCallback
	if cb == nil {
		cb = c.retryCallback()
	}
	resp, e := c.executor(request, c.resolver, opt.decodeFunc).Try(ctx, c.config.Timeout, cb)
	if e != nil {
		err = c.translateError(request, e)
	}

	if r, ok := resp.(*Response); ok {
		return r, err
	}
	if r, ok := resp.(Response); ok {
		return &r, err
	}
	if err == nil {
		err = NewInternalError(fmt.Errorf("expected a *Response, but HTTP response decode function returned %T", resp))
	}
	return nil, err
}
// retryCallback returns the default retry policy: keep retrying while the attempt number is below
// MaxRetries and the error is not an ErrorTypeResponse (i.e. not an error-status response), waiting
// RetryBackoff between attempts. The config is read at call time, not captured at creation.
func (c *client) retryCallback() RetryCallback {
	return func(attempt int, _ interface{}, err error) (bool, time.Duration) {
		again := attempt < c.config.MaxRetries && !errors.Is(err, ErrorTypeResponse)
		return again, c.config.RetryBackoff
	}
}
// translateError converts an error produced while executing a request into an *Error:
//   - context cancellation/deadline errors become a server-timeout error;
//   - discovery errors get a message naming the request;
//   - *Error values of the http client category are returned as-is;
//   - anything else becomes an internal error.
//
// The result is never nil. If an error matches a category by code but is not an *Error (errors.As
// fails), it is wrapped instead of dereferencing a nil *Error or returning a typed nil.
func (c *client) translateError(req *Request, err error) (ret *Error) {
	switch {
	case errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded):
		e := fmt.Errorf("remote HTTP call [%s] %s timed out after %v", req.Method, req.Path, c.config.Timeout)
		return NewServerTimeoutError(e)
	case errors.Is(err, ErrorSubTypeDiscovery):
		if errors.As(err, &ret) {
			return ret.WithMessage("remote HTTP call [%s] %s: no endpoints available", req.Method, req.Path)
		}
		e := fmt.Errorf("remote HTTP call [%s] %s: no endpoints available", req.Method, req.Path)
		return NewNoEndpointFoundError(e, err)
	case errors.Is(err, ErrorCategoryHttpClient):
		if errors.As(err, &ret) {
			return ret
		}
		// not an *Error: handled as uncategorized below
	}
	e := fmt.Errorf("uncategorized remote HTTP call [%s] %s error: %v", req.Method, req.Path, err)
	return NewInternalError(e)
}
// updateConfig installs config and rebuilds the effective hook lists.
// Each list is the client defaults followed by the config's hooks. Hooks implementing the Configurable*
// interface are replaced by their WithConfig(config) result, and each list is then stably sorted with
// order.OrderedFirstCompare. Fresh slices are allocated, so the defaults are never modified.
func (c *client) updateConfig(config *ClientConfig) {
	c.config = config

	befores := make([]BeforeHook, 0, len(c.defaults.before)+len(config.BeforeHooks))
	befores = append(befores, c.defaults.before...)
	befores = append(befores, config.BeforeHooks...)
	for i, hook := range befores {
		if cfgHook, ok := hook.(ConfigurableBeforeHook); ok {
			befores[i] = cfgHook.WithConfig(config)
		}
	}
	order.SortStable(befores, order.OrderedFirstCompare)
	c.before = befores

	afters := make([]AfterHook, 0, len(c.defaults.after)+len(config.AfterHooks))
	afters = append(afters, c.defaults.after...)
	afters = append(afters, config.AfterHooks...)
	for i, hook := range afters {
		if cfgHook, ok := hook.(ConfigurableAfterHook); ok {
			afters[i] = cfgHook.WithConfig(config)
		}
	}
	order.SortStable(afters, order.OrderedFirstCompare)
	c.after = afters
}
// shallowCopy returns a new *client with the same field values; pointers and slices are shared.
func (c *client) shallowCopy() *client {
	clone := new(client)
	*clone = *c
	return clone
}
// executor builds the Retryable that performs one attempt of the request:
//  1. uses request.Path directly if it is an absolute http/https URL, otherwise asks resolver for the
//     target (a missing resolver yields a "no endpoint found" error instead of a nil-pointer panic);
//  2. creates and encodes the *http.Request;
//  3. runs the before hooks, sends it with c.config.HTTPClient, runs the after hooks;
//  4. decodes the response with dec. The response body is closed when the attempt returns.
func (c *client) executor(request *Request, resolver TargetResolver, dec DecodeResponseFunc) Retryable {
	return func(ctx context.Context) (interface{}, error) {
		target, e := url.Parse(request.Path)
		// only need to resolve the target if the request.Path is not an absolute http(s) URL
		if e != nil || !supportedSchemes.Has(target.Scheme) {
			// a client from NewClient without WithService/WithBaseUrl has no resolver
			if resolver == nil {
				return nil, NewNoEndpointFoundError(fmt.Errorf("cannot resolve target of [%s] %s: client has no service or base URL configured", request.Method, request.Path))
			}
			target, e = resolver.Resolve(ctx, request)
			if e != nil {
				return nil, e
			}
		}
		req, e := request.CreateFunc(ctx, request.Method, target)
		if e != nil {
			return nil, e
		}
		if e := request.encodeHTTPRequest(ctx, req); e != nil {
			return nil, e
		}
		for _, hook := range c.before {
			ctx = hook.Before(ctx, req)
		}
		resp, e := c.config.HTTPClient.Do(req.WithContext(ctx))
		if e != nil {
			return nil, e
		}
		defer func() { _ = resp.Body.Close() }()
		for _, hook := range c.after {
			ctx = hook.After(ctx, resp)
		}
		return dec(ctx, resp)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package httpclient
import (
"context"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
"net/http"
"net/url"
"time"
)
// Client is an HTTP client with optional service discovery, load balancing, retries and hooks.
type Client interface {
	// Execute sends the provided request and parses the response using the provided ResponseOptions.
	// When using the default decoding function:
	// - it returns a non-nil Response only if the response has a 2XX status code
	// - it returns a non-nil error for 4XX, 5XX status codes or any other type of error
	// - the returned error can be cast to *Error
	Execute(ctx context.Context, request *Request, opts ...ResponseOptions) (*Response, error)
	// WithService creates a client for a specific service with the given instance selectors.
	// The returned client is responsible for tracking service instance changes with the help of the
	// discovery package, and for performing load-balancing and retrying.
	// The returned client is goroutine-safe and can be reused.
	WithService(service string, opts ...SDOptions) (Client, error)
	// WithBaseUrl creates a client with a specific base URL.
	// The returned client is responsible for performing retrying.
	// The returned client is goroutine-safe and can be reused.
	WithBaseUrl(baseUrl string) (Client, error)
	// WithConfig creates a shallow copy of the client with the specified config.
	// Service (with LB) or BaseURL cannot be changed with this method.
	// Zero-valued fields of the provided config are not applied (the current values are kept).
	// The returned client is goroutine-safe and can be reused.
	WithConfig(config *ClientConfig) Client
}
// ClientOptions is used for creating Client and its customizers
type ClientOptions func(opt *ClientOption)

// ClientOption carries initial configurations of Clients
type ClientOption struct {
	// ClientConfig is the initial config; NewClient starts from DefaultConfig()
	ClientConfig
	// SDClient is required for Client.WithService
	SDClient discovery.Client
	// DefaultSelector is the instance selector applied before any SDOptions in WithService
	DefaultSelector discovery.InstanceMatcher
	// DefaultBeforeHooks/DefaultAfterHooks are always combined with the hooks of every ClientConfig
	DefaultBeforeHooks []BeforeHook
	DefaultAfterHooks  []AfterHook
}
// ClientConfig is used to change Client's config
type ClientConfig struct {
	// HTTPClient underlying http.Client to use
	HTTPClient *http.Client
	// BeforeHooks hooks to invoke before sending the HTTP request
	BeforeHooks []BeforeHook
	// AfterHooks hooks to invoke after the HTTP response is received, before it is decoded
	AfterHooks []AfterHook
	// MaxRetries number of retries in case of error. Negative value means no retry; zero means
	// "not set" (inherited when merged into an existing config).
	// Note: by default, non-2XX response status code error is not retried
	MaxRetries int
	// RetryBackoff time to wait between retries. Negative means retry immediately; zero means "not set".
	RetryBackoff time.Duration
	// RetryCallback allows fine control when and how to retry.
	// If set, this overrides MaxRetries and RetryBackoff
	RetryCallback RetryCallback
	// Timeout how long to wait for each execution.
	// Note: this is total duration including RetryBackoff between each attempt, not per-retry timeout.
	Timeout time.Duration
	// Logger used for logging request/response
	Logger log.ContextualLogger
	// Logging configuration of request/response logging
	Logging LoggingConfig
}
// LoggingConfig controls request/response logging.
type LoggingConfig struct {
	// Level is the log level used for request/response log entries
	Level log.LoggingLevel
	// DetailsLevel controls how much is logged; below LogDetailsLevelMinimum requests are not logged
	DetailsLevel LogDetailsLevel
	// SanitizeHeaders names headers whose values should be masked (Authorization by default)
	// NOTE(review): masking is implemented outside this view — confirm exact behavior
	SanitizeHeaders utils.StringSet
	// ExcludeHeaders names headers presumably omitted from logs — confirm against the log builder
	ExcludeHeaders utils.StringSet
}
// ClientCustomizer modifies a ClientOption before a Client is created.
type ClientCustomizer interface {
	Customize(opt *ClientOption)
}

// ClientCustomizerFunc adapts a plain function to ClientCustomizer.
type ClientCustomizerFunc func(opt *ClientOption)

// Customize calls fn(opt).
func (fn ClientCustomizerFunc) Customize(opt *ClientOption) {
	fn(opt)
}
// Hook intercepts both HTTP requests and responses; it satisfies both BeforeHook and AfterHook and can
// therefore be used in ClientConfig and ClientOptions.
// NOTE(review): ordering is applied via order.OrderedFirstCompare. Judging by HighestReservedHookOrder
// (-10000) vs LowestReservedHookOrder (10000), the "highest" order is the smallest Order() value —
// confirm against the order package.
type Hook interface {
	// Before is invoked after the HTTP request is encoded and before the request is sent.
	// The implementing class could also implement order.Ordered interface. Highest order is invoked first
	Before(context.Context, *http.Request) context.Context
	// After is invoked after HTTP response is returned and before the response is decoded.
	// The implementing class could also implement order.Ordered interface. Highest order is invoked first
	After(context.Context, *http.Response) context.Context
}
// BeforeHook is used for ClientConfig and ClientOptions.
// The implementing class could also implement order.Ordered interface. Highest order is invoked first.
// The returned context is passed to the next hook and used for sending the request.
type BeforeHook interface {
	// Before is invoked after the HTTP request is encoded and before the request is sent.
	Before(context.Context, *http.Request) context.Context
}

// ConfigurableBeforeHook is an additional interface that BeforeHook can implement.
// When a client config is applied, the hook is replaced by WithConfig(cfg).
type ConfigurableBeforeHook interface {
	WithConfig(cfg *ClientConfig) BeforeHook
}

// AfterHook is used for ClientConfig and ClientOptions.
// The implementing class could also implement order.Ordered interface. Highest order is invoked first.
// The returned context is passed to the next hook and to the response decoder.
type AfterHook interface {
	// After is invoked after HTTP response is returned and before the response is decoded.
	After(context.Context, *http.Response) context.Context
}

// ConfigurableAfterHook is an additional interface that AfterHook can implement.
// When a client config is applied, the hook is replaced by WithConfig(cfg).
type ConfigurableAfterHook interface {
	WithConfig(cfg *ClientConfig) AfterHook
}
// TargetResolver resolves the target URL for a request whose path is not an absolute http(s) URL.
type TargetResolver interface {
	Resolve(ctx context.Context, req *Request) (*url.URL, error)
}

// TargetResolverFunc adapts a plain function to TargetResolver.
type TargetResolverFunc func(ctx context.Context, req *Request) (*url.URL, error)

// Resolve calls fn(ctx, req).
func (fn TargetResolverFunc) Resolve(ctx context.Context, req *Request) (*url.URL, error) {
	return fn(ctx, req)
}
/************************
Common Impl.
************************/
// DefaultConfig returns a new ClientConfig with the package defaults:
// http.DefaultClient, no extra hooks, 3 retries, 1 minute timeout, the package logger, and
// header-level logging with the Authorization header sanitized.
func DefaultConfig() *ClientConfig {
	logging := LoggingConfig{
		DetailsLevel:    LogDetailsLevelHeaders,
		SanitizeHeaders: utils.NewStringSet(HeaderAuthorization),
		ExcludeHeaders:  utils.NewStringSet(),
	}
	cfg := ClientConfig{
		HTTPClient:  http.DefaultClient,
		BeforeHooks: []BeforeHook{},
		AfterHooks:  []AfterHook{},
		MaxRetries:  3,
		Timeout:     1 * time.Minute,
		Logger:      logger,
		Logging:     logging,
	}
	return &cfg
}
// defaultServiceConfig returns the config overrides for internal load-balanced services:
// it adds the token passthrough hook.
func defaultServiceConfig() *ClientConfig {
	cfg := new(ClientConfig)
	cfg.BeforeHooks = []BeforeHook{HookTokenPassthrough()}
	return cfg
}

// defaultExtHostConfig returns the config overrides for external hosts: currently none.
func defaultExtHostConfig() *ClientConfig {
	return new(ClientConfig)
}
// mergeConfig fills unset fields of dst with values from src, in place; src is not modified.
//
// "Unset" means nil for pointers, interfaces, slices, sets and callbacks, a non-positive Timeout,
// LogDetailsLevelUnknown for Logging.DetailsLevel, and log.LevelOff for Logging.Level.
// MaxRetries and RetryBackoff inherit from src only when zero. Negative values mean "disabled" and are
// normalized to 0.
func mergeConfig(dst *ClientConfig, src *ClientConfig) {
	// retry settings
	if dst.MaxRetries == 0 {
		dst.MaxRetries = src.MaxRetries
	} else if dst.MaxRetries < 0 {
		dst.MaxRetries = 0
	}
	if dst.RetryBackoff == 0 {
		dst.RetryBackoff = src.RetryBackoff
	} else if dst.RetryBackoff < 0 {
		dst.RetryBackoff = 0
	}
	if dst.RetryCallback == nil {
		dst.RetryCallback = src.RetryCallback
	}

	// transport and timeout
	if dst.HTTPClient == nil {
		dst.HTTPClient = src.HTTPClient
	}
	if dst.Timeout <= 0 {
		dst.Timeout = src.Timeout
	}

	// hooks
	if dst.BeforeHooks == nil {
		dst.BeforeHooks = src.BeforeHooks
	}
	if dst.AfterHooks == nil {
		dst.AfterHooks = src.AfterHooks
	}

	// logging
	if dst.Logger == nil {
		dst.Logger = src.Logger
	}
	if dst.Logging.Level == log.LevelOff {
		dst.Logging.Level = src.Logging.Level
	}
	if dst.Logging.DetailsLevel == LogDetailsLevelUnknown {
		dst.Logging.DetailsLevel = src.Logging.DetailsLevel
	}
	if dst.Logging.SanitizeHeaders == nil {
		dst.Logging.SanitizeHeaders = src.Logging.SanitizeHeaders
	}
	if dst.Logging.ExcludeHeaders == nil {
		dst.Logging.ExcludeHeaders = src.Logging.ExcludeHeaders
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package httpclient
import (
"errors"
"fmt"
. "github.com/cisco-open/go-lanai/pkg/utils/error"
"net/http"
)
const (
	// Reserved http client reserved error range (category 0xcc shifted into the category position)
	Reserved = 0xcc << ReservedOffset
)

// All "Type" values are used as mask
// In every block below the blank `_ = iota` consumes 0, so the first named constant uses iota 1 and no
// named code has a zero component at its own level.
const (
	_                      = iota
	ErrorTypeCodeInternal  = Reserved + iota<<ErrorTypeOffset
	ErrorTypeCodeTransport
	ErrorTypeCodeResponse
)

// All "SubType" values are used as mask
// sub types of ErrorTypeCodeInternal
const (
	_                         = iota
	ErrorSubTypeCodeInternal  = ErrorTypeCodeInternal + iota<<ErrorSubTypeOffset
	ErrorSubTypeCodeDiscovery
)

// All "SubType" values are used as mask
// sub types of ErrorTypeCodeTransport
const (
	_                       = iota
	ErrorSubTypeCodeTimeout = ErrorTypeCodeTransport + iota<<ErrorSubTypeOffset
)

// All "SubType" values are used as mask
// sub types of ErrorTypeCodeResponse
const (
	_                           = iota
	ErrorSubTypeCodeServerSide  = ErrorTypeCodeResponse + iota<<ErrorSubTypeOffset
	ErrorSubTypeCodeClientSide
	ErrorSubTypeCodeMedia
)

// Concrete error codes below are their sub-type code plus a small sequence number.

// ErrorSubTypeCodeInternal
const (
	_                 = iota
	ErrorCodeInternal = ErrorSubTypeCodeInternal + iota
)

// ErrorSubTypeCodeDiscovery
const (
	_                        = iota
	ErrorCodeDiscoveryDown   = ErrorSubTypeCodeDiscovery + iota
	ErrorCodeNoEndpointFound
)

// ErrorSubTypeCodeTimeout
const (
	_                      = iota
	ErrorCodeServerTimeout = ErrorSubTypeCodeTimeout + iota
)

// ErrorSubTypeCodeMedia
const (
	_                      = iota
	ErrorCodeMediaType     = ErrorSubTypeCodeMedia + iota
	ErrorCodeSerialization
)

// ErrorSubTypeCodeClientSide
const (
	_                          = iota
	ErrorCodeGenericClientSide = ErrorSubTypeCodeClientSide + iota
	ErrorCodeUnauthorized
	ErrorCodeForbidden
)

// ErrorSubTypeCodeServerSide
const (
	_                          = iota
	ErrorCodeGenericServerSide = ErrorSubTypeCodeServerSide + iota
)
// ErrorTypes, can be used in errors.Is to match whole categories, types or sub-types of errors.
// The wrapped messages describe each group (the media sub-type previously carried a copy-pasted
// "server timeout" message).
var (
	ErrorCategoryHttpClient   = NewErrorCategory(Reserved, errors.New("error type: http client"))
	ErrorTypeInternal         = NewErrorType(ErrorTypeCodeInternal, errors.New("error type: internal"))
	ErrorTypeTransport        = NewErrorType(ErrorTypeCodeTransport, errors.New("error type: http transport"))
	ErrorTypeResponse         = NewErrorType(ErrorTypeCodeResponse, errors.New("error type: error status code"))
	ErrorSubTypeInternalError = NewErrorSubType(ErrorSubTypeCodeInternal, errors.New("error sub-type: internal"))
	ErrorSubTypeDiscovery     = NewErrorSubType(ErrorSubTypeCodeDiscovery, errors.New("error sub-type: discovery"))
	ErrorSubTypeTimeout       = NewErrorSubType(ErrorSubTypeCodeTimeout, errors.New("error sub-type: server timeout"))
	ErrorSubTypeServerSide    = NewErrorSubType(ErrorSubTypeCodeServerSide, errors.New("error sub-type: server side"))
	ErrorSubTypeClientSide    = NewErrorSubType(ErrorSubTypeCodeClientSide, errors.New("error sub-type: client side"))
	ErrorSubTypeMedia         = NewErrorSubType(ErrorSubTypeCodeMedia, errors.New("error sub-type: media"))
)
// ErrorDiscoveryDown is the concrete "service discovery is not available" error; it can be used in
// errors.Is for an exact match.
var ErrorDiscoveryDown = NewError(ErrorCodeDiscoveryDown, "service discovery is not available")

// init reserves this package's error category (Reserved range) with the error registry,
// presumably so that no other package claims the same code range.
func init() {
	Reserve(ErrorCategoryHttpClient)
}
// ErrorResponseBody is a decoded error payload of an HTTP response.
type ErrorResponseBody interface {
	Error() string
	Message() string
	Details() map[string]string
}

// ErrorResponse keeps a copy of the error HTTP response together with its raw and decoded bodies.
type ErrorResponse struct {
	http.Response
	RawBody []byte
	Body    ErrorResponseBody
}

// Error returns the decoded body's error text, or the HTTP status line when there is no decoded body.
func (er ErrorResponse) Error() string {
	if er.Body != nil {
		return er.Body.Error()
	}
	return er.Status
}

// Message returns the decoded body's message, or the HTTP status line when there is no decoded body.
func (er ErrorResponse) Message() string {
	if er.Body != nil {
		return er.Body.Message()
	}
	return er.Status
}
// Error can optionally store *http.Response's status code, headers and body
type Error struct {
	CodedError
	Response *ErrorResponse
}

// Error implements the error interface; see String.
func (e Error) Error() string {
	return e.String()
}

// String describes the error: the coded error text when there is no response, otherwise the HTTP
// status, followed by the decoded body's message when available.
func (e Error) String() string {
	if e.Response == nil {
		return e.CodedError.Error()
	}
	if e.Response.Body == nil {
		return fmt.Sprintf("error HTTP response [%s]", e.Response.Status)
	}
	return fmt.Sprintf("error HTTP response [%s]: %s", e.Response.Status, e.Response.Message())
}

// WithMessage returns a new *Error with the same code and response but a new formatted message.
func (e Error) WithMessage(msg string, args ...interface{}) *Error {
	coded := NewCodedError(e.CodedError.Code(), fmt.Errorf(msg, args...))
	return newError(coded, e.Response)
}
/*
*********************
Constructors
*********************
*/
// newError wraps a copy of codedErr and the optional error response into an *Error.
func newError(codedErr *CodedError, errResp *ErrorResponse) *Error {
	return &Error{CodedError: *codedErr, Response: errResp}
}

// NewError creates an *Error with the given code, value and causes, without an HTTP response.
func NewError(code int64, e interface{}, causes ...interface{}) *Error {
	coded := NewCodedError(code, e, causes...)
	return newError(coded, nil)
}
// NewErrorWithResponse creates an *Error with an ErrorResponse built from resp and rawBody.
// If the given "e" is an ErrorResponseBody, it is saved into the ErrorResponse as the decoded body.
// A nil resp (which previously caused a nil-pointer panic) yields an *Error without an ErrorResponse.
func NewErrorWithResponse(code int64, e interface{}, resp *http.Response, rawBody []byte, causes ...interface{}) *Error {
	coded := NewCodedError(code, e, causes...)
	if resp == nil {
		return newError(coded, nil)
	}
	body, _ := e.(ErrorResponseBody)
	errResp := &ErrorResponse{
		Response: *resp,
		RawBody:  rawBody,
		Body:     body,
	}
	return newError(coded, errResp)
}
// NewErrorWithStatusCode creates an *Error with an ErrorResponse, choosing the error code by status:
// 401 -> ErrorCodeUnauthorized, 403 -> ErrorCodeForbidden, other 4XX -> ErrorCodeGenericClientSide,
// 5XX -> ErrorCodeGenericServerSide.
// If the given "e" is an ErrorResponseBody, it is saved into the ErrorResponse.
// A nil resp or a non-error status code yields an internal error instead.
func NewErrorWithStatusCode(e interface{}, resp *http.Response, rawBody []byte, causes ...interface{}) *Error {
	if resp == nil {
		return NewError(ErrorCodeInternal, fmt.Errorf("attempt to create response error without an HTTP response"))
	}
	var code int64
	switch {
	case resp.StatusCode == http.StatusUnauthorized:
		code = ErrorCodeUnauthorized
	case resp.StatusCode == http.StatusForbidden:
		code = ErrorCodeForbidden
	case resp.StatusCode >= 400 && resp.StatusCode <= 499:
		code = ErrorCodeGenericClientSide
	case resp.StatusCode >= 500 && resp.StatusCode <= 599:
		code = ErrorCodeGenericServerSide
	default:
		return NewError(ErrorCodeInternal, fmt.Errorf("attempt to create response error with non error status code %d", resp.StatusCode))
	}
	return NewErrorWithResponse(code, e, resp, rawBody, causes...)
}
// NewInternalError creates an internal *Error with the concrete ErrorCodeInternal code, consistent
// with the other constructors and with NewErrorWithStatusCode. It still belongs to the
// ErrorSubTypeInternalError sub-type, which is presumably matched by mask in errors.Is.
func NewInternalError(value interface{}, causes ...interface{}) *Error {
	return NewError(ErrorCodeInternal, value, causes...)
}
// NewDiscoveryDownError creates an *Error with code ErrorCodeDiscoveryDown.
func NewDiscoveryDownError(value interface{}, causes ...interface{}) *Error {
	return newError(NewCodedError(ErrorCodeDiscoveryDown, value, causes...), nil)
}

// NewNoEndpointFoundError creates an *Error with code ErrorCodeNoEndpointFound.
func NewNoEndpointFoundError(value interface{}, causes ...interface{}) *Error {
	return newError(NewCodedError(ErrorCodeNoEndpointFound, value, causes...), nil)
}

// NewServerTimeoutError creates an *Error with code ErrorCodeServerTimeout.
func NewServerTimeoutError(value interface{}, causes ...interface{}) *Error {
	return newError(NewCodedError(ErrorCodeServerTimeout, value, causes...), nil)
}

// NewMediaTypeError creates an *Error with code ErrorCodeMediaType that carries the HTTP response.
func NewMediaTypeError(value interface{}, resp *http.Response, rawBody []byte, causes ...interface{}) *Error {
	return NewErrorWithResponse(ErrorCodeMediaType, value, resp, rawBody, causes...)
}

// NewSerializationError creates an *Error with code ErrorCodeSerialization that carries the HTTP response.
func NewSerializationError(value interface{}, resp *http.Response, rawBody []byte, causes ...interface{}) *Error {
	return NewErrorWithResponse(ErrorCodeSerialization, value, resp, rawBody, causes...)
}

// NewRequestSerializationError creates an *Error with code ErrorCodeSerialization, without a response
// (used when the request itself cannot be serialized).
func NewRequestSerializationError(value interface{}, causes ...interface{}) *Error {
	return newError(NewCodedError(ErrorCodeSerialization, value, causes...), nil)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package httpclient
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"net/http"
"time"
)
// Hook order values. Hooks are sorted with order.OrderedFirstCompare; the "Highest" reserved order is
// the most negative value.
const (
	HighestReservedHookOrder  = -10000
	LowestReservedHookOrder   = 10000
	// HookOrderTokenPassthrough runs just after the highest reserved order
	HookOrderTokenPassthrough = HighestReservedHookOrder + 10
	// HookOrderRequestLogger is the lowest reserved order, so the request is logged after other before hooks
	HookOrderRequestLogger    = LowestReservedHookOrder
	// HookOrderResponseLogger is the highest reserved order, so the response is logged before other after hooks
	HookOrderResponseLogger   = HighestReservedHookOrder
)

const (
	// logKey is the key under which request/response log data is attached to log entries
	logKey = "remote-http"
)

// Size units in bytes.
const (
	kb = 1024
	mb = kb * kb
	gb = mb * kb
)
// ctxKeyStartTimeType is a private context-key type. A bare struct{}{} key would be equal to any other
// anonymous empty-struct key used by other code, so values could collide in the same context.
type ctxKeyStartTimeType struct{}

// ctxKeyStartTime is the context key under which the request logger stores the request start time.
var ctxKeyStartTime = ctxKeyStartTimeType{}
/*********************
Function Alias
*********************/
// BeforeHookFunc adapts a plain function to BeforeHook.
type BeforeHookFunc func(context.Context, *http.Request) context.Context

// Before delegates to the wrapped function.
func (f BeforeHookFunc) Before(c context.Context, r *http.Request) context.Context {
	return f(c, r)
}

// AfterHookFunc adapts a plain function to AfterHook.
type AfterHookFunc func(context.Context, *http.Response) context.Context

// After delegates to the wrapped function.
func (f AfterHookFunc) After(c context.Context, r *http.Response) context.Context {
	return f(c, r)
}
/*********************
Ordered
*********************/
// BeforeHookWithOrder wraps hook so that it also implements order.Ordered with the given order.
// Only hook's Before method is exposed by the wrapper; other interfaces hook may implement
// (e.g. ConfigurableBeforeHook) are hidden.
func BeforeHookWithOrder(order int, hook BeforeHook) BeforeHook {
	wrapped := orderedBeforeHook{BeforeHook: hook, order: order}
	return &wrapped
}

// orderedBeforeHook implements BeforeHook, order.Ordered
type orderedBeforeHook struct {
	BeforeHook
	order int
}

// Order returns the configured order value.
func (o orderedBeforeHook) Order() int {
	return o.order
}

// AfterHookWithOrder wraps hook so that it also implements order.Ordered with the given order.
// Only hook's After method is exposed by the wrapper; other interfaces hook may implement
// (e.g. ConfigurableAfterHook) are hidden.
func AfterHookWithOrder(order int, hook AfterHook) AfterHook {
	wrapped := orderedAfterHook{AfterHook: hook, order: order}
	return &wrapped
}

// orderedAfterHook implements AfterHook, order.Ordered
type orderedAfterHook struct {
	AfterHook
	order int
}

// Order returns the configured order value.
func (o orderedAfterHook) Order() int {
	return o.order
}
/****************************
Token Passthrough Hook
****************************/
// HookTokenPassthrough returns a before hook (order HookOrderTokenPassthrough) that forwards the
// caller's OAuth2 access token. If the request has no Authorization header and the current security
// context holds a fully authenticated oauth2.Authentication with an access token, it sets
// "Authorization: Bearer <token>". Otherwise the request is left untouched.
func HookTokenPassthrough() BeforeHook {
	passthrough := func(ctx context.Context, request *http.Request) context.Context {
		if request.Header.Get(HeaderAuthorization) != "" {
			// explicitly set by the caller; keep it
			return ctx
		}
		auth, ok := security.Get(ctx).(oauth2.Authentication)
		if ok && security.IsFullyAuthenticated(auth) && auth.AccessToken() != nil {
			request.Header.Set(HeaderAuthorization, fmt.Sprintf("Bearer %s", auth.AccessToken().Value()))
		}
		return ctx
	}
	return BeforeHookWithOrder(HookOrderTokenPassthrough, BeforeHookFunc(passthrough))
}
/*************************
Logger Hook
*************************/
// requestLoggerHook logs outgoing requests using the logger and logging settings of its ClientConfig,
// and records the request start time in the context under ctxKeyStartTime.
type requestLoggerHook struct {
	*ClientConfig
}

// Order returns HookOrderRequestLogger.
func (h requestLoggerHook) Order() int {
	return HookOrderRequestLogger
}

// Before captures the start time (UTC) before logging and stores it in the returned context.
func (h requestLoggerHook) Before(ctx context.Context, req *http.Request) context.Context {
	start := time.Now().UTC()
	logRequest(ctx, req, h.Logger, &h.Logging)
	return context.WithValue(ctx, ctxKeyStartTime, start)
}

// WithConfig returns a request logger bound to cfg.
func (requestLoggerHook) WithConfig(cfg *ClientConfig) BeforeHook {
	return requestLoggerHook{ClientConfig: cfg}
}

// HookRequestLogger returns a request-logging before hook bound to cfg.
func HookRequestLogger(cfg *ClientConfig) BeforeHook {
	return requestLoggerHook{ClientConfig: cfg}
}
// responseLoggerHook is an AfterHook that logs received responses using the embedded
// ClientConfig's Logger and Logging settings.
type responseLoggerHook struct {
	*ClientConfig
}

// Order returns HookOrderResponseLogger.
func (responseLoggerHook) Order() int {
	return HookOrderResponseLogger
}

// After logs resp and returns ctx unchanged.
func (hook responseLoggerHook) After(ctx context.Context, resp *http.Response) context.Context {
	logResponse(ctx, resp, hook.Logger, &hook.Logging)
	return ctx
}

// WithConfig returns a responseLoggerHook bound to cfg.
func (responseLoggerHook) WithConfig(cfg *ClientConfig) AfterHook {
	return responseLoggerHook{ClientConfig: cfg}
}
// HookResponseLogger returns an AfterHook that logs received responses according to cfg.
func HookResponseLogger(cfg *ClientConfig) AfterHook {
	var hook responseLoggerHook
	return hook.WithConfig(cfg)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package httpclient
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
"net/http"
"strings"
"time"
)
// requestLog is the structured (key/value) form of a logged HTTP request.
type requestLog struct {
Method string `json:"method,omitempty"`
// URL is the request URI (path and query, via URL.RequestURI), not the absolute URL.
URL string `json:"url,omitempty"`
// Headers holds header values after exclusion/masking (see sanitizedHeaders).
Headers map[string]string `json:"headers,omitempty"`
// Body currently only carries a placeholder note; body logging is not implemented.
Body string `json:"body,omitempty"`
}
// responseLog is the structured form of a logged HTTP response, including the method and URI of
// the request that produced it.
type responseLog struct {
requestLog
// SC is the response status code.
SC int `json:"statusCode,omitempty"`
// RespLength is resp.ContentLength; net/http reports -1 when the length is unknown.
RespLength int `json:"length,omitempty"`
// Duration is the time elapsed since the start time stored by the request logger hook,
// truncated to microseconds; it is 0 when no start time was found in the context.
Duration time.Duration `json:"duration,omitempty"`
}
// logRequest logs a summary of the outgoing request r at logging.Level. Nothing is logged when
// logging.DetailsLevel is below LogDetailsLevelMinimum. The structured form is attached under logKey.
//
// The pre-rendered message is passed as an argument, not as the format string. It contains
// caller-controlled text (URL, header values) that may include '%', e.g. percent-encoded paths,
// which Printf would otherwise interpret as formatting verbs and garble.
func logRequest(ctx context.Context, r *http.Request, logger log.ContextualLogger, logging *LoggingConfig) {
	if logging.DetailsLevel < LogDetailsLevelMinimum {
		return
	}
	// kv is already a *requestLog; pass it as-is rather than a pointer-to-pointer
	kv, msg := constructRequestLog(r, logging)
	logger.WithContext(ctx).WithKV(logKey, kv).WithLevel(logging.Level).Printf("%s", msg)
}
// logResponse logs a summary of resp at logging.Level. Nothing is logged when
// logging.DetailsLevel is below LogDetailsLevelMinimum. The request start time stored in ctx
// (if any) is used to report the duration, and the structured form is attached under logKey.
//
// The pre-rendered message is passed as an argument, not as the format string. It contains
// header values and other remote-controlled text that may include '%', which Printf would
// otherwise interpret as formatting verbs.
func logResponse(ctx context.Context, resp *http.Response, logger log.ContextualLogger, logging *LoggingConfig) {
	if logging.DetailsLevel < LogDetailsLevelMinimum {
		return
	}
	// kv is already a *responseLog; pass it as-is rather than a pointer-to-pointer
	kv, msg := constructResponseLog(ctx, resp, logging)
	logger.WithContext(ctx).WithKV(logKey, kv).WithLevel(logging.Level).Printf("%s", msg)
}
func constructRequestLog(r *http.Request, logging *LoggingConfig) (*requestLog, string) {
msg := []string{fmt.Sprintf("[HTTP Request] %s %#v", r.Method, r.URL.RequestURI())}
kv := requestLog{
Method: r.Method,
URL: r.URL.RequestURI(),
}
if logging.DetailsLevel >= LogDetailsLevelHeaders {
var text string
kv.Headers, text = sanitizedHeaders(r.Header, logging.SanitizeHeaders, logging.ExcludeHeaders)
msg = append(msg, text)
}
if logging.DetailsLevel >= LogDetailsLevelFull {
kv.Body = "Request logging is currently unsupported"
msg = append(msg, kv.Body)
}
return &kv, strings.Join(msg, " | ")
}
// constructResponseLog builds the structured entry and the " | "-separated message for resp.
// The duration is measured from the start time stored in ctx under ctxKeyStartTime (0 if absent).
// Sanitized headers are included from LogDetailsLevelHeaders upward, and a body placeholder
// note is appended at LogDetailsLevelFull.
func constructResponseLog(ctx context.Context, resp *http.Response, logging *LoggingConfig) (*responseLog, string) {
	var elapsed time.Duration
	if start, found := ctx.Value(ctxKeyStartTime).(time.Time); found {
		elapsed = time.Since(start).Truncate(time.Microsecond)
	}

	entry := &responseLog{
		requestLog: requestLog{
			Method: resp.Request.Method,
			URL:    resp.Request.URL.RequestURI(),
		},
		SC:         resp.StatusCode,
		RespLength: int(resp.ContentLength),
		Duration:   elapsed,
	}

	parts := make([]string, 1, 3)
	parts[0] = fmt.Sprintf("[HTTP Response] %3d | %10v | %6s",
		resp.StatusCode, elapsed, formatSize(entry.RespLength))
	if logging.DetailsLevel >= LogDetailsLevelHeaders {
		headers, text := sanitizedHeaders(resp.Header, logging.SanitizeHeaders, logging.ExcludeHeaders)
		entry.Headers = headers
		parts = append(parts, text)
	}
	if logging.DetailsLevel >= LogDetailsLevelFull {
		entry.Body = "Response logging is currently unsupported"
		parts = append(parts, entry.Body)
	}
	return entry, strings.Join(parts, " | ")
}
// sanitizedHeaders renders headers for logging. Headers named in exclude are omitted, headers
// named in sanitize have their value replaced by "******", and all other values are joined
// with a single space. It returns the rendered key/value map and a ", "-separated list of
// "Name[value]" entries.
//
// Names are matched case-insensitively. http.Header keys are stored in canonical form
// (e.g. "X-Api-Key"), whereas configured names may be written in any case (e.g. "x-api-key").
// An exact-case lookup would silently fail to mask such headers and leak their values into logs.
func sanitizedHeaders(headers http.Header, sanitize utils.StringSet, exclude utils.StringSet) (map[string]string, string) {
	kv := make(map[string]string, len(headers))
	msg := make([]string, 0, len(headers))
	for k, v := range headers {
		if headerSetHas(exclude, k) {
			continue
		}
		value := "******"
		if !headerSetHas(sanitize, k) {
			value = strings.Join(v, " ")
		}
		kv[k] = value
		msg = append(msg, k+`[`+value+`]`)
	}
	return kv, strings.Join(msg, ", ")
}

// headerSetHas reports whether set contains name, ignoring case. The exact-match lookup
// is tried first, since canonical names are the common case.
func headerSetHas(set utils.StringSet, name string) bool {
	if set.Has(name) {
		return true
	}
	for candidate := range set {
		if strings.EqualFold(candidate, name) {
			return true
		}
	}
	return false
}
// formatSize renders the byte count n for logs. Values below kb are printed as plain bytes
// ("512B"); larger values are printed with two decimals in KB, MB or GB.
func formatSize(n int) string {
	if n < kb {
		return fmt.Sprintf("%dB", n)
	}
	unit, divisor := "GB", float64(gb)
	if n < mb {
		unit, divisor = "KB", float64(kb)
	} else if n < gb {
		unit, divisor = "MB", float64(mb)
	}
	return fmt.Sprintf("%.2f%s", float64(n)/divisor, unit)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package httpclient
import (
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
"go.uber.org/fx"
"time"
)
// logger is the package-level logger named "HttpClient".
var logger = log.New("HttpClient")
// Module is the bootstrap module of the http client. It registers the embedded default
// properties (defaultConfigFS) and provides HttpClientProperties, the Client, and the
// tracing ClientCustomizer.
var Module = &bootstrap.Module{
Name: "http-client",
Precedence: bootstrap.HttpClientPrecedence,
Options: []fx.Option{
appconfig.FxEmbeddedDefaults(defaultConfigFS),
fx.Provide(bindHttpClientProperties),
fx.Provide(provideHttpClient),
fx.Provide(tracingProvider()),
},
}
// Use registers the http-client Module with bootstrap.
func Use() {
	bootstrap.Register(Module)
}
// FxClientCustomizers takes providers of ClientCustomizer and wraps each one in an fx.Annotated
// that contributes its result to the FxGroup value group.
func FxClientCustomizers(providers ...interface{}) []fx.Annotated {
	annotated := make([]fx.Annotated, 0, len(providers))
	for _, provider := range providers {
		annotated = append(annotated, fx.Annotated{Group: FxGroup, Target: provider})
	}
	return annotated
}
// clientDI collects the dependencies of provideHttpClient from the fx container.
type clientDI struct {
fx.In
Properties HttpClientProperties
// DiscClient is optional; it stays nil when no discovery client is provided.
DiscClient discovery.Client `optional:"true"`
// Customizers are all ClientCustomizer values in the "http-client" group
// (presumably the same name as FxGroup — confirm).
Customizers []ClientCustomizer `group:"http-client"`
}
// provideHttpClient builds the application's Client. The first option applies the discovery client
// and HttpClientProperties (retries, timeout, logging). Every injected ClientCustomizer follows in
// injection order (presumably applied in slice order by NewClient, so customizers can override
// property values — confirm).
func provideHttpClient(di clientDI) Client {
	props := di.Properties
	applyProperties := func(opt *ClientOption) {
		opt.SDClient = di.DiscClient
		opt.MaxRetries = props.MaxRetries
		opt.Timeout = time.Duration(props.Timeout)
		opt.Logging.Level = props.Logger.Level
		opt.Logging.DetailsLevel = props.Logger.DetailsLevel
		opt.Logging.SanitizeHeaders = utils.NewStringSet(props.Logger.SanitizeHeaders...)
		opt.Logging.ExcludeHeaders = utils.NewStringSet(props.Logger.ExcludeHeaders...)
	}

	options := make([]ClientOptions, 0, 1+len(di.Customizers))
	options = append(options, applyProperties)
	for _, c := range di.Customizers {
		options = append(options, c.Customize)
	}
	return NewClient(options...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package httpclient
import (
"embed"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"strings"
"time"
)
const (
// PropertiesPrefix is the configuration prefix from which HttpClientProperties is bound.
PropertiesPrefix = "integrate.http"
)
// LogDetailsLevel values, in increasing order of verbosity. Logging is skipped entirely below
// LogDetailsLevelMinimum. Headers are included from LogDetailsLevelHeaders upward, and
// LogDetailsLevelFull additionally appends a body placeholder note.
const (
LogDetailsLevelUnknown LogDetailsLevel = iota
LogDetailsLevelNone
LogDetailsLevelMinimum
LogDetailsLevelHeaders
LogDetailsLevelFull
)
// Textual names of the LogDetailsLevel values, used by String, MarshalText and UnmarshalText.
const (
logDetailsLevelUnknownText = "unknown"
logDetailsLevelNoneText = "off"
logDetailsLevelMinimumText = "minimum"
logDetailsLevelHeadersText = "headers"
logDetailsLevelFullText = "full"
)
// Lookup tables between LogDetailsLevel values and their textual names. Keys of
// logDetailsLevelAtoI are lower-cased, because UnmarshalText lower-cases its input before lookup.
var (
logDetailsLevelAtoI = map[string]LogDetailsLevel{
strings.ToLower(logDetailsLevelUnknownText): LogDetailsLevelUnknown,
strings.ToLower(logDetailsLevelNoneText): LogDetailsLevelNone,
strings.ToLower(logDetailsLevelMinimumText): LogDetailsLevelMinimum,
strings.ToLower(logDetailsLevelHeadersText): LogDetailsLevelHeaders,
strings.ToLower(logDetailsLevelFullText): LogDetailsLevelFull,
}
logDetailsLevelItoA = map[LogDetailsLevel]string{
LogDetailsLevelUnknown: logDetailsLevelUnknownText,
LogDetailsLevelNone: logDetailsLevelNoneText,
LogDetailsLevelMinimum: logDetailsLevelMinimumText,
LogDetailsLevelHeaders: logDetailsLevelHeadersText,
LogDetailsLevelFull: logDetailsLevelFullText,
}
)
// LogDetailsLevel controls how much of each HTTP exchange is logged. Its textual names
// ("unknown", "off", "minimum", "headers", "full") are used by MarshalText/UnmarshalText.
type LogDetailsLevel int

// String returns the textual name of l, or "off" for values that have no name.
func (l LogDetailsLevel) String() string {
	name, known := logDetailsLevelItoA[l]
	if !known {
		return logDetailsLevelNoneText
	}
	return name
}

// MarshalText implements encoding.TextMarshaler using String.
func (l LogDetailsLevel) MarshalText() ([]byte, error) {
	return []byte(l.String()), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. Names are matched case-insensitively.
// Unrecognised text leaves l unchanged and is not reported as an error.
func (l *LogDetailsLevel) UnmarshalText(data []byte) error {
	parsed, known := logDetailsLevelAtoI[strings.ToLower(string(data))]
	if known {
		*l = parsed
	}
	return nil
}
// defaultConfigFS embeds defaults-integrate-http.yml, which Module registers as embedded
// default configuration via appconfig.FxEmbeddedDefaults.
//go:embed defaults-integrate-http.yml
var defaultConfigFS embed.FS
// HttpClientProperties are the http client settings bound from PropertiesPrefix ("integrate.http").
type HttpClientProperties struct {
MaxRetries int `json:"max-retries"` // negative value means no retry
// Timeout is converted to time.Duration and applied as ClientOption.Timeout.
Timeout utils.Duration `json:"timeout"`
Logger LoggerProperties `json:"logger"`
}
// LoggerProperties configure request/response logging of the http client.
type LoggerProperties struct {
// Level is the log level at which requests and responses are logged.
Level log.LoggingLevel `json:"level"`
// DetailsLevel controls how much of each exchange is logged (see LogDetailsLevel).
DetailsLevel LogDetailsLevel `json:"details-level"`
// SanitizeHeaders lists headers whose values are masked in logs.
SanitizeHeaders utils.CommaSeparatedSlice `json:"sanitize-headers"`
// ExcludeHeaders lists headers that are omitted from logs entirely.
ExcludeHeaders utils.CommaSeparatedSlice `json:"exclude-headers"`
}
// newHttpClientProperties returns HttpClientProperties with built-in defaults: 3 retries, a
// 1 minute timeout, and debug-level logging of headers with HeaderAuthorization masked. Bound
// configuration values are layered on top of these by bindHttpClientProperties.
func newHttpClientProperties() *HttpClientProperties {
	props := new(HttpClientProperties)
	props.MaxRetries = 3
	props.Timeout = utils.Duration(time.Minute)
	props.Logger = LoggerProperties{
		Level:           log.LevelDebug,
		DetailsLevel:    LogDetailsLevelHeaders,
		SanitizeHeaders: utils.CommaSeparatedSlice{HeaderAuthorization},
		ExcludeHeaders:  utils.CommaSeparatedSlice{},
	}
	return props
}
// bindHttpClientProperties binds configuration under PropertiesPrefix onto the defaults from
// newHttpClientProperties and returns the result by value. It panics if binding fails.
func bindHttpClientProperties(ctx *bootstrap.ApplicationContext) HttpClientProperties {
	props := newHttpClientProperties()
	err := ctx.Config().Bind(props, PropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind HttpClientProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package httpclient
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)
// CreateRequestFunc is a function to create http.Request with given context, method and target URL
type CreateRequestFunc func(ctx context.Context, method string, target *url.URL) (*http.Request, error)
// EncodeRequestFunc is a function to modify http.Request for encoding given value
type EncodeRequestFunc func(ctx context.Context, req *http.Request, val interface{}) error
// RequestOptions used to configure Request in NewRequest
type RequestOptions func(r *Request)
// Request wraps all information about the request
type Request struct {
// Path is joined onto the target's base/context path by the TargetResolver.
Path string
Method string
// Params are merged into the query string of the outgoing request (see applyParams).
Params map[string]string
// Headers are copied onto the outgoing request, replacing same-named headers (see encodeHTTPRequest).
Headers http.Header
// Body is encoded into the outgoing request by BodyEncodeFunc.
Body interface{}
BodyEncodeFunc EncodeRequestFunc
// CreateFunc builds the initial http.Request for the resolved target URL.
CreateFunc CreateRequestFunc
}
// NewRequest creates a Request for the given path and HTTP method. Params and Headers start empty,
// the body is encoded by EncodeJSONRequestBody, and the http.Request is built by
// defaultRequestCreateFunc. opts are applied in order and may override any of these.
func NewRequest(path, method string, opts ...RequestOptions) *Request {
	req := &Request{
		Path:           path,
		Method:         method,
		Params:         make(map[string]string),
		Headers:        make(http.Header),
		BodyEncodeFunc: EncodeJSONRequestBody,
		CreateFunc:     defaultRequestCreateFunc,
	}
	for _, apply := range opts {
		apply(req)
	}
	return req
}
// encodeHTTPRequest applies r onto req and returns any error from body encoding:
//   - every header present in r.Headers replaces the existing values of that header on req,
//     and all of its values are copied (WithHeader uses Header.Add, so a header may carry
//     several values; copying only Headers.Get(k) would drop all but the first)
//   - r.Params are applied via applyParams
//   - r.Body is encoded with r.BodyEncodeFunc
func (r Request) encodeHTTPRequest(ctx context.Context, req *http.Request) error {
	// Clear first, then add, so that non-canonical duplicates in r.Headers (e.g. "x-a" and "X-A",
	// possible when the map is populated directly) are merged rather than overwriting each other.
	// Del/Add canonicalize the key themselves.
	for k := range r.Headers {
		req.Header.Del(k)
	}
	for k, values := range r.Headers {
		for _, v := range values {
			req.Header.Add(k, v)
		}
	}
	r.applyParams(req)
	return r.BodyEncodeFunc(ctx, req, r.Body)
}
// applyParams merges r.Params into req's query string. Each param replaces any existing query
// value with the same key, while other existing query values (e.g. ones carried by a static base
// URL) are kept. Previously they were dropped, but only when params were present.
// Keys and values are both query-escaped, and pairs are sorted by key (url.Values.Encode),
// so the resulting URL is deterministic.
func (r Request) applyParams(req *http.Request) {
	if len(r.Params) == 0 {
		return
	}
	query := req.URL.Query()
	for k, v := range r.Params {
		query.Set(k, v)
	}
	req.URL.RawQuery = query.Encode()
}
/**********************
Defaults
**********************/
func EncodeJSONRequestBody(_ context.Context, r *http.Request, body interface{}) error {
if body == nil {
r.Body = nil
r.GetBody = nil
r.ContentLength = 0
return nil
}
if len(r.Header.Values(HeaderContentType)) == 0 {
r.Header.Set(HeaderContentType, MediaTypeJson)
}
var b bytes.Buffer
r.Body = io.NopCloser(&b)
err := json.NewEncoder(&b).Encode(body)
if err != nil {
return err
}
buf := b.Bytes()
r.GetBody = func() (io.ReadCloser, error) {
r := bytes.NewReader(buf)
return io.NopCloser(r), nil
}
r.ContentLength = int64(b.Len())
return nil
}
// EncodeURLEncodedRequestBody encodes body, which must be url.Values, as a
// www-form-urlencoded request body. It sets Body, GetBody and ContentLength, and defaults
// Content-Type to MediaTypeFormUrlEncoded unless it is already set. Any other body type yields
// a request serialization error.
func EncodeURLEncodedRequestBody(_ context.Context, r *http.Request, body interface{}) error {
	values, isValues := body.(url.Values)
	if !isValues {
		return NewRequestSerializationError(fmt.Errorf("www-form-urlencoded body expects url.Values but got %T", body))
	}
	if len(r.Header.Values(HeaderContentType)) == 0 {
		r.Header.Set(HeaderContentType, MediaTypeFormUrlEncoded)
	}

	form := values.Encode()
	newBody := func() (io.ReadCloser, error) {
		return io.NopCloser(strings.NewReader(form)), nil
	}
	r.GetBody = newBody
	r.Body, _ = newBody() // newBody never returns an error
	r.ContentLength = int64(len(form))
	return nil
}
// defaultRequestCreateFunc builds a body-less http.Request bound to ctx for the given method and target.
func defaultRequestCreateFunc(ctx context.Context, method string, target *url.URL) (*http.Request, error) {
	rawURL := target.String()
	return http.NewRequestWithContext(ctx, method, rawURL, nil)
}
/**********************
Request Options
**********************/
// WithoutHeader returns a RequestOptions that removes header key from the Request.
// An empty key yields an option that does nothing.
func WithoutHeader(key string) RequestOptions {
	if key == "" {
		return noop()
	}
	return func(r *Request) {
		r.Headers.Del(key)
	}
}
// WithHeader returns a RequestOptions that appends value to header key; existing values are kept.
// If either key or value is empty, the option does nothing.
func WithHeader(key, value string) RequestOptions {
	if key == "" || value == "" {
		return noop()
	}
	return func(r *Request) {
		r.Headers.Add(key, value)
	}
}
// WithParam returns a RequestOptions that sets query param key to value. An empty value removes
// the param instead, and an empty key yields an option that does nothing.
func WithParam(key, value string) RequestOptions {
	if key == "" {
		return noop()
	}
	if value == "" {
		return func(r *Request) {
			delete(r.Params, key)
		}
	}
	return func(r *Request) {
		r.Params[key] = value
	}
}
// WithBody returns a RequestOptions that sets the Request body, to be encoded by its BodyEncodeFunc.
func WithBody(body interface{}) RequestOptions {
	setBody := func(r *Request) {
		r.Body = body
	}
	return setBody
}
// WithRequestBodyEncoder returns a RequestOptions that sets the body encoder.
// A nil enc restores the default EncodeJSONRequestBody.
func WithRequestBodyEncoder(enc EncodeRequestFunc) RequestOptions {
	return func(r *Request) {
		if enc == nil {
			r.BodyEncodeFunc = EncodeJSONRequestBody
			return
		}
		r.BodyEncodeFunc = enc
	}
}
// WithRequestCreator returns a RequestOptions that sets the function used to build the
// http.Request. A nil value restores the default defaultRequestCreateFunc.
func WithRequestCreator(enc CreateRequestFunc) RequestOptions {
	return func(r *Request) {
		if enc == nil {
			r.CreateFunc = defaultRequestCreateFunc
			return
		}
		r.CreateFunc = enc
	}
}
// WithBasicAuth returns a RequestOptions that adds an HTTP Basic Authorization header built
// from username and password (base64 of "username:password").
func WithBasicAuth(username, password string) RequestOptions {
	credentials := []byte(username + ":" + password)
	return WithHeader(HeaderAuthorization, "Basic "+base64.StdEncoding.EncodeToString(credentials))
}
// WithUrlEncodedBody returns a RequestOptions that sets body as the Request body and selects
// EncodeURLEncodedRequestBody as its encoder.
func WithUrlEncodedBody(body url.Values) RequestOptions {
	setBody := WithBody(body)
	setEncoder := WithRequestBodyEncoder(EncodeURLEncodedRequestBody)
	return mergeRequestOptions(setBody, setEncoder)
}
// mergeRequestOptions combines opts into a single RequestOptions that applies them in order.
func mergeRequestOptions(opts ...RequestOptions) RequestOptions {
	return func(r *Request) {
		for i := range opts {
			opts[i](r)
		}
	}
}
// noop returns a request option that leaves the Request untouched; it is used for
// option constructors whose arguments make them meaningless.
func noop() func(r *Request) {
	return func(*Request) {
		// intentionally does nothing
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package httpclient
import (
	"context"
	"fmt"
	"github.com/cisco-open/go-lanai/pkg/discovery"
	"net"
	"net/url"
	"path"
	"time"
)
/****************************
SD TargetResolver
****************************/
// SDOptions configures an SDTargetResolver by mutating its SDOption (see NewSDTargetResolver).
type SDOptions func(opt *SDOption)
// SDOption holds the settings of SDTargetResolver.
type SDOption struct {
// Selector Used to filter targets during service discovery.
// Default: discovery.InstanceIsHealthy()
Selector discovery.InstanceMatcher
// InvalidateOnError Whether previously known targets should be invalidated (i.e. no longer used)
// when service discovery reports an error. When false, previously known targets are always used
// and NewSDTargetResolver forces InvalidateTimeout to -1.
// Default: true
InvalidateOnError bool
// InvalidateTimeout How long to keep using previously known targets after service discovery starts reporting errors.
// < 0: Always use previously known targets, equivalent to InvalidateOnError = false.
//      (When InvalidateOnError = true, NewSDTargetResolver resets a negative value to 0.)
// == 0: Never use previously known targets.
// > 0: Use previously known targets for the specified duration since the first error received from SD client
// Default: -1 if InvalidateOnError = false, 0 if InvalidateOnError = true
InvalidateTimeout time.Duration
// Scheme HTTP scheme to use.
// If not set, the actual scheme is resolved from target instance's Meta/Tags and DefaultScheme value.
// Possible values: "http", "https", "" (empty string).
// Default: ""
Scheme string
// DefaultScheme Default HTTP scheme to use, if Scheme is not set and resolver cannot resolve scheme from Meta/Tags.
// Possible values: "http", "https".
// Default: "http"
DefaultScheme string
// ContextPath Path prefix for any given Request.
// If not set, the context path is resolved from target instance's Meta.
// e.g. "/auth/api"
// Default: ""
ContextPath string
}
// SDTargetResolver implements the TargetResolver interface using a discovery.Instancer to resolve the target's address.
// It also attempts to resolve the http scheme and context path from the instance's tags/meta.
// When service discovery fails with an error, this resolver can keep using previously found instances,
// assuming they are still good, for a period of time configured by SDOption.
// Currently, this resolver only supports round-robin load balancing.
type SDTargetResolver struct {
SDOption
instancer discovery.Instancer
balancer balancer[*discovery.Instance]
}
// NewSDTargetResolver creates a TargetResolver that works with discovery.Instancer, using
// round-robin balancing. Defaults: healthy instances only, DefaultScheme "http",
// InvalidateOnError true. InvalidateTimeout is normalized afterwards: it is forced to -1 when
// InvalidateOnError is false, and clamped to 0 when it is negative.
// The returned error is currently always nil.
// See SDTargetResolver
func NewSDTargetResolver(instancer discovery.Instancer, opts ...SDOptions) (*SDTargetResolver, error) {
	option := SDOption{
		Selector:          discovery.InstanceIsHealthy(),
		InvalidateOnError: true,
		DefaultScheme:     "http",
	}
	for _, configure := range opts {
		configure(&option)
	}

	// make InvalidateTimeout agree with InvalidateOnError
	switch {
	case !option.InvalidateOnError:
		option.InvalidateTimeout = -1
	case option.InvalidateTimeout < 0:
		option.InvalidateTimeout = 0 // invalidate immediately
	}

	resolver := &SDTargetResolver{
		SDOption:  option,
		instancer: instancer,
		balancer:  newRoundRobinBalancer[*discovery.Instance](),
	}
	return resolver, nil
}
// Resolve picks an instance of the discovered service and builds the target URL for req.
// It returns a no-endpoint error when the service is unknown or no instance is selected. It
// returns a discovery-down error when discovery reports an error and handleError rejects the use
// of last known instances.
func (ke *SDTargetResolver) Resolve(_ context.Context, req *Request) (*url.URL, error) {
	notFound := func() error {
		return fmt.Errorf("cannot find service [%s]", ke.instancer.ServiceName())
	}

	svc := ke.instancer.Service()
	switch {
	case svc == nil:
		return nil, NewNoEndpointFoundError(notFound())
	case svc.Err != nil && !ke.handleError(svc):
		return nil, NewDiscoveryDownError(notFound(), svc.Err)
	}

	candidates := svc.Instances(ke.Selector)
	inst, err := ke.balancer.Balance(candidates)
	if err != nil || inst == nil {
		return nil, NewNoEndpointFoundError(notFound())
	}
	return ke.targetURL(inst, req)
}
// targetURL builds the absolute URL for req on inst. The error result is currently always nil.
//   - path: req.Path joined onto the context path. The context path is SDOption.ContextPath if set,
//     otherwise the instance's discovery.InstanceMetaKeyContextPath meta value (possibly empty).
//   - scheme: SDOption.Scheme if set; otherwise "http"/"https" when the instance matches
//     httpInstanceMatcher/httpsInstanceMatcher; otherwise SDOption.DefaultScheme.
//   - host: inst.Address, with the port appended when inst.Port is in 1..65535.
//
// net.JoinHostPort brackets IPv6 addresses ("[::1]:8080"). Plain "%s:%d" formatting
// would produce the unusable host "::1:8080".
func (ke *SDTargetResolver) targetURL(inst *discovery.Instance, req *Request) (target *url.URL, err error) {
	ctxPath := ke.ContextPath
	if len(ctxPath) == 0 && inst.Meta != nil {
		ctxPath = inst.Meta[discovery.InstanceMetaKeyContextPath]
	}

	scheme := ke.Scheme
	if len(scheme) == 0 {
		if m, e := httpInstanceMatcher.Matches(inst); m && e == nil {
			scheme = "http"
		} else if m, e := httpsInstanceMatcher.Matches(inst); m && e == nil {
			scheme = "https"
		} else {
			scheme = ke.DefaultScheme
		}
	}

	host := inst.Address
	if inst.Port > 0 && inst.Port <= 0xffff {
		host = net.JoinHostPort(inst.Address, fmt.Sprintf("%d", inst.Port))
	}
	target = &url.URL{
		Scheme: scheme,
		Host:   host,
		Path:   path.Join(ctxPath, req.Path),
	}
	return target, nil
}
// handleError decides whether previously known instances may still be used while discovery
// reports an error for svc. It returns true when last known endpoints should be used:
//   - InvalidateTimeout < 0: always
//   - InvalidateTimeout == 0, or the time of the first error is unknown: never
//   - InvalidateTimeout > 0: only while less than InvalidateTimeout has passed since svc.FirstErrAt
//
// handleError is NOT goroutine-safe.
func (ke *SDTargetResolver) handleError(svc *discovery.Service) bool {
	switch {
	case ke.InvalidateTimeout < 0:
		// nothing to do
		return true
	case ke.InvalidateTimeout == 0 || svc.FirstErrAt.IsZero():
		// do not return last known
		return false
	default:
		// Within the grace period, keep using the last known endpoints. The previous comparison was
		// inverted and only allowed them once the grace period had already expired.
		return time.Now().Before(svc.FirstErrAt.Add(ke.InvalidateTimeout))
	}
}
package httpclient
import (
"context"
"fmt"
"net/url"
"path"
)
/***********************
BaseUrlTargetResolver
***********************/
// NewStaticTargetResolver returns a TargetResolverFunc that resolves every Request against the
// fixed, absolute baseUrl. The request path is joined onto the base path (path.Join also
// cleans the result), while the scheme, host and query of baseUrl are kept. It returns an error
// if baseUrl cannot be parsed or is not absolute.
func NewStaticTargetResolver(baseUrl string) (TargetResolverFunc, error) {
	base, e := url.Parse(baseUrl)
	if e != nil {
		return nil, e
	}
	if !base.IsAbs() {
		return nil, fmt.Errorf(`expect absolute base URL, but got "%s"`, baseUrl)
	}
	return func(_ context.Context, req *Request) (*url.URL, error) {
		uri := *base
		uri.Path = path.Join(base.Path, req.Path)
		return &uri, nil
	}, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package httpclient
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"net/http"
"reflect"
)
// DecodeResponseFunc decodes an http.Response into a response value (the default JSON decoder produces *Response) or an error.
type DecodeResponseFunc func(context.Context, *http.Response) (response interface{}, err error)
// Response is the result produced by the default JSON decoder.
type Response struct {
StatusCode int
Headers http.Header
// Body is the decoded body. When the decode target is a pointer to a map, slice or interface, the pointed-to value is stored.
Body interface{}
// RawBody is the undecoded response payload; it is excluded from JSON serialization.
RawBody []byte `json:"-"`
}
// ResponseOptions configures how a response is decoded (see JsonBody, JsonErrorBody, CustomResponseDecoder).
type ResponseOptions func(opt *responseOption)
// responseOption holds decoding settings; see fallbackResponseOptions for the defaults.
type responseOption struct {
// body is the decode target for successful JSON responses.
body interface{}
// errBody is the decode target for JSON error responses (status code > 299).
errBody ErrorResponseBody
// decodeFunc, when set, replaces the default JSON decoding entirely.
decodeFunc DecodeResponseFunc
}
// fallbackResponseOptions fills in defaults when no custom decoder is configured. The defaults are
// a generic map as the body target, defaultErrorBody as the error target, and the standard JSON
// decoder. If decodeFunc is already set, opt is left untouched.
func fallbackResponseOptions(opt *responseOption) {
	if opt.decodeFunc != nil {
		return
	}
	if opt.body == nil {
		opt.body = &map[string]interface{}{}
	}
	if opt.errBody == nil {
		opt.errBody = &defaultErrorBody{}
	}
	opt.decodeFunc = makeJsonDecodeResponseFunc(opt)
}
// JsonBody returns a ResponseOptions that sets the value into which a successful response body
// is JSON-decoded (typically a pointer).
func JsonBody(body interface{}) ResponseOptions {
	setBody := func(opt *responseOption) {
		opt.body = body
	}
	return setBody
}
// JsonErrorBody returns a ResponseOptions that sets the ErrorResponseBody into which an error
// response is JSON-decoded.
func JsonErrorBody(errBody ErrorResponseBody) ResponseOptions {
	setErrBody := func(opt *responseOption) {
		opt.errBody = errBody
	}
	return setErrBody
}
// CustomResponseDecoder returns a ResponseOptions that installs dec as the decoding function for
// http.Response. When set, it takes precedence over JsonBody and JsonErrorBody.
func CustomResponseDecoder(dec DecodeResponseFunc) ResponseOptions {
	setDecoder := func(opt *responseOption) {
		opt.decodeFunc = dec
	}
	return setDecoder
}
// makeJsonDecodeResponseFunc returns opt.decodeFunc if it is set. Otherwise it returns the
// standard JSON decoder:
//   - status >= 300: the error body is decoded into opt.errBody and returned as an error
//   - otherwise: the body is decoded into opt.body (read at call time) and wrapped in *Response.
//     If opt.body points to a map, slice or interface, Response.Body holds the pointed-to value.
func makeJsonDecodeResponseFunc(opt *responseOption) DecodeResponseFunc {
	if opt.decodeFunc != nil {
		return opt.decodeFunc
	}
	return func(ctx context.Context, resp *http.Response) (interface{}, error) {
		if resp.StatusCode >= 300 {
			return nil, handleStatusCodeError(resp, opt.errBody)
		}

		target := opt.body
		raw, err := decodeJsonBody(resp, target)
		if err != nil {
			return nil, err
		}

		// dereference pointers to reference-like kinds so callers get the value directly
		body := target
		if rv := reflect.ValueOf(target); rv.Kind() == reflect.Ptr {
			switch elem := rv.Elem(); elem.Kind() {
			case reflect.Map, reflect.Slice, reflect.Interface:
				body = elem.Interface()
			}
		}

		return &Response{
			StatusCode: resp.StatusCode,
			Headers:    resp.Header,
			Body:       body,
			RawBody:    raw,
		}, nil
	}
}
// handleStatusCodeError decodes a non-2xx response body into errBody and returns the
// corresponding status-code error. If decoding fails with an *Error, that error is returned
// with an explanatory message; any other decoding error is returned unchanged.
func handleStatusCodeError(resp *http.Response, errBody interface{}) error {
	raw, err := decodeJsonBody(resp, errBody)
	if err == nil {
		return NewErrorWithStatusCode(errBody, resp, raw)
	}
	var httpErr *Error
	if !errors.As(err, &httpErr) {
		return err
	}
	return httpErr.WithMessage("unable to parse error response: %v", err)
}
// decodeJsonBody read body from http.Response and decode into given "body"
// function panic if "body" is nil
func decodeJsonBody(resp *http.Response, body interface{}) ([]byte, error) {
defer func() {_ = resp.Body.Close()}()
// check media type
if e := validateMediaType(MediaTypeJson, resp); e != nil {
return nil, e
}
// decode, and keep the raw bytes
var data []byte
data, e := io.ReadAll(resp.Body)
if e != nil {
return nil, NewSerializationError(fmt.Errorf("response IO error: %s", e), resp, data)
}
if len(data) > 0 {
if e := json.Unmarshal(data, body); e != nil {
return data, NewSerializationError(fmt.Errorf("response unmarshal error: %s", e), resp, data)
}
}
return data, nil
}
// validateMediaType checks that resp's Content-Type header parses as a media type equal to
// expected. Parameters such as charset are ignored. It returns a media-type *Error describing the
// mismatch or parse failure, or nil when the media type matches.
//
// The comparison uses the expected argument. Comparing against MediaTypeJson regardless of the
// argument would reject every caller asking for a different media type.
func validateMediaType(expected string, resp *http.Response) *Error {
	contentType := resp.Header.Get(HeaderContentType)
	mediaType, _, e := mime.ParseMediaType(contentType)
	if e != nil {
		return NewMediaTypeError(fmt.Errorf("received invalid content type %s", contentType), resp, nil, e)
	}
	if mediaType != expected {
		return NewMediaTypeError(fmt.Errorf("unsupported media type: %s, expected %s", mediaType, expected), resp, nil)
	}
	return nil
}
/*************************
Error Unmarshal
*************************/
// jsonErrorBody is the JSON wire format of an error response.
type jsonErrorBody struct {
	Error   string            `json:"error,omitempty"`
	Message string            `json:"message,omitempty"`
	Desc    string            `json:"error_description,omitempty"`
	Details map[string]string `json:"details,omitempty"`
}

// defaultErrorBody implements ErrorResponseBody, json.Marshaler, json.Unmarshaler
// on top of the jsonErrorBody wire format.
type defaultErrorBody struct {
	jsonErrorBody
}

// Error returns the "error" field.
func (b defaultErrorBody) Error() string {
	return b.jsonErrorBody.Error
}

// Message returns the "message" field, falling back to "error_description" when it is empty.
func (b defaultErrorBody) Message() string {
	if msg := b.jsonErrorBody.Message; msg != "" {
		return msg
	}
	return b.jsonErrorBody.Desc
}

// Details returns the "details" field.
func (b defaultErrorBody) Details() map[string]string {
	return b.jsonErrorBody.Details
}

// MarshalJSON implements json.Marshaler by encoding the wire format.
func (b defaultErrorBody) MarshalJSON() ([]byte, error) {
	return json.Marshal(b.jsonErrorBody)
}

// UnmarshalJSON implements json.Unmarshaler by decoding into the wire format.
func (b *defaultErrorBody) UnmarshalJSON(data []byte) error {
	return json.Unmarshal(data, &b.jsonErrorBody)
}
package httpclient
import (
"context"
"time"
)
// Retryable is a unit of work that Try may execute multiple times.
type Retryable func(ctx context.Context) (interface{}, error)

// RetryCallback retry control function.
// RetryCallback is executed each time when non-nil error is returned.
// "n": indicate the iteration number of attempts.
// "rs" and "err" indicate the result of the current attempt
// RetryCallback returns whether Retryable need to keep trying and optionally wait for "backoff" before next attempt
type RetryCallback func(n int, rs interface{}, err error) (shouldContinue bool, backoff time.Duration)

// Try keeps executing the Retryable until
//  1. No error is returned
//  2. Timeout reached (this includes time spent waiting for a backoff)
//  3. RetryCallback tells it to stop
//
// Each attempt runs in a separate goroutine with a context bounded by timeout. The RetryCallback is
// invoked in the current goroutine whenever an attempt returns a non-nil error. A nil cb retries
// immediately until the timeout.
// If the execution finishes without any successful result, the latest error is returned if
// available, otherwise context.Err().
func (r Retryable) Try(ctx context.Context, timeout time.Duration, cb RetryCallback) (interface{}, error) {
	if cb == nil {
		cb = func(_ int, _ interface{}, _ error) (bool, time.Duration) {
			return true, 0
		}
	}
	type result struct {
		value interface{}
		err   error
	}
	timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	var lastErr error
	// giveUp reports the most informative error once the deadline has passed.
	giveUp := func() (interface{}, error) {
		if lastErr == nil {
			return nil, timeoutCtx.Err()
		}
		return nil, lastErr
	}

	for i := 1; ; i++ {
		// buffered, so a late attempt never blocks after Try has returned
		rsCh := make(chan result, 1)
		go func() {
			var rs result
			rs.value, rs.err = r(timeoutCtx)
			rsCh <- rs
			close(rsCh)
		}()

		var rs result
		select {
		case <-timeoutCtx.Done():
			return giveUp()
		case rs = <-rsCh:
		}
		if rs.err == nil {
			return rs.value, nil
		}
		lastErr = rs.err

		again, backoff := cb(i, rs.value, rs.err)
		if !again {
			return rs.value, rs.err
		}
		if backoff <= 0 {
			continue
		}
		// Wait for the backoff, but give up as soon as the timeout fires; time.Sleep here
		// could block far beyond the requested timeout.
		timer := time.NewTimer(backoff)
		select {
		case <-timeoutCtx.Done():
			timer.Stop()
			return giveUp()
		case <-timer.C:
		}
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package httpclient
import (
"context"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"go.uber.org/fx"
"net"
"net/http"
"strconv"
)
// tracingOpName is the operation-name prefix of client spans; the HTTP method is appended to it.
const tracingOpName = "remote-http"
// tracingCustomizer is a ClientCustomizer that adds span start/finish hooks when a tracer is available.
type tracingCustomizer struct {
tracer opentracing.Tracer
}
// tracingProvider returns the fx.Annotated provider that contributes the tracing ClientCustomizer
// to the client customizer group.
func tracingProvider() fx.Annotated {
	annotated := FxClientCustomizers(newTracingCustomizer)
	return annotated[0]
}
// tracingDI collects the optional opentracing.Tracer; Tracer stays nil when none is provided.
type tracingDI struct {
fx.In
Tracer opentracing.Tracer `optional:"true"`
}
// newTracingCustomizer creates the tracing ClientCustomizer. Without a Tracer its Customize is a no-op.
func newTracingCustomizer(di tracingDI) ClientCustomizer {
	c := new(tracingCustomizer)
	c.tracer = di.Tracer
	return c
}
// Customize appends the span-starting BeforeHook and the span-finishing AfterHook to the client's
// default hooks. It does nothing when no tracer is configured.
func (c *tracingCustomizer) Customize(opt *ClientOption) {
	if c.tracer != nil {
		opt.DefaultBeforeHooks = append(opt.DefaultBeforeHooks, startSpanHook(c.tracer))
		opt.DefaultAfterHooks = append(opt.DefaultAfterHooks, finishSpanHook(c.tracer))
	}
}
// startSpanHook returns a BeforeHook, ordered with order.Highest, that starts a client span named
// "remote-http <METHOD>" via DescendantOrNoSpan (presumably no span when ctx carries none —
// see the tracing package). The span is tagged with method/URL and peer host/port, and its span
// context is injected into the outgoing request headers.
func startSpanHook(tracer opentracing.Tracer) BeforeHook {
	fn := func(ctx context.Context, req *http.Request) context.Context {
		name := tracingOpName + " " + req.Method
		opts := []tracing.SpanOption{
			tracing.SpanKind(ext.SpanKindRPCClientEnum),
			tracing.SpanTag("method", req.Method),
			tracing.SpanTag("url", req.URL.RequestURI()),
		}

		// standard tags
		hostname := req.URL.Host
		var port uint16
		if host, portString, e := net.SplitHostPort(req.URL.Host); e == nil {
			hostname = host
			// ParseUint with bitSize 16 rejects negative and >65535 values instead of letting
			// them wrap around on conversion to uint16. Such ports are simply not tagged.
			if p, pe := strconv.ParseUint(portString, 10, 16); pe == nil {
				port = uint16(p)
			}
		}
		opts = append(opts,
			tracing.SpanHttpMethod(req.Method),
			tracing.SpanHttpUrl(req.URL.String()),
			func(span opentracing.Span) {
				ext.PeerHostname.Set(span, hostname)
				if port != 0 {
					ext.PeerPort.Set(span, port)
				}
			},
		)

		// propagation
		opts = append(opts, spanPropagation(req, tracer))
		return tracing.WithTracer(tracer).
			WithOpName(name).
			WithOptions(opts...).
			DescendantOrNoSpan(ctx)
	}
	return BeforeHookWithOrder(order.Highest, BeforeHookFunc(fn))
}
// finishSpanHook returns an AfterHook, ordered with order.Lowest, that tags the current span with
// the response status code and finishes it. It returns the context produced by FinishAndRewind.
func finishSpanHook(tracer opentracing.Tracer) AfterHook {
	finish := func(ctx context.Context, response *http.Response) context.Context {
		sc := response.StatusCode
		return tracing.WithTracer(tracer).
			WithOptions(tracing.SpanTag("sc", sc), tracing.SpanHttpStatusCode(sc)).
			FinishAndRewind(ctx)
	}
	return AfterHookWithOrder(order.Lowest, AfterHookFunc(finish))
}
// spanPropagation returns a SpanOption that injects the span's context into req's headers using
// the opentracing HTTPHeaders format.
func spanPropagation(req *http.Request, tracer opentracing.Tracer) tracing.SpanOption {
	return func(span opentracing.Span) {
		carrier := opentracing.HTTPHeadersCarrier(req.Header)
		// Injection errors are deliberately ignored: the request is still sent, just without trace headers.
		_ = tracer.Inject(span.Context(), opentracing.HTTPHeaders, carrier)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"embed"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"time"
)
// PropertiesPrefix is the configuration prefix that BindSecurityIntegrationProperties binds from.
const PropertiesPrefix = "integrate.security"
// DefaultConfigFS embeds defaults-integrate-security.yml from this package's directory
// (presumably the default "integrate.security" configuration values — verify against the file).
//go:embed defaults-integrate-security.yml
var DefaultConfigFS embed.FS
// SecurityIntegrationProperties configures security integration. It is bound from the
// PropertiesPrefix ("integrate.security") configuration section by BindSecurityIntegrationProperties.
//goland:noinspection GoNameStartsWithPackageName
type SecurityIntegrationProperties struct {
	// FailureBackOff is how long to wait after a failed attempt before a retry is allowed. Until this
	// period passes, the integration framework will not re-attempt switching context to the same
	// combination of username and tenant.
	FailureBackOff utils.Duration `json:"failure-back-off"`
	// GuaranteedValidity is how long a requested security context is guaranteed to remain valid.
	// When such validity cannot be guaranteed (e.g. this value is longer than the token's validity),
	// FailureBackOff is used instead and a new token is requested after the back-off passes.
	GuaranteedValidity utils.Duration `json:"guaranteed-validity"`
	// Endpoints locates the auth service and its login / switch-context endpoints.
	Endpoints AuthEndpointsProperties `json:"endpoints"`
	// Client holds OAuth2 client credentials (consumer not visible here — presumably the seclient package).
	Client ClientCredentialsProperties `json:"client"`
	// Accounts holds the default and additional account credentials
	// (presumably turned into the scope manager's known credentials / system accounts — verify).
	Accounts AccountsProperties `json:"accounts"`
}
// ClientCredentialsProperties holds an OAuth2 client ID and secret.
// Note that the secret is bound from the "secret" key (not "client-secret").
type ClientCredentialsProperties struct {
	ClientId string `json:"client-id"`
	ClientSecret string `json:"secret"`
}
// AuthEndpointsProperties describes how to reach the auth service and which endpoint paths to call.
type AuthEndpointsProperties struct {
	// BaseUrl overrides service discovery and load-balancing.
	// When set, ServiceName, Scheme and ContextPath are ignored.
	BaseUrl string `json:"base-url"`
	// ServiceName is the name of the auth service, used by service discovery to resolve the
	// authentication/authorization URL.
	ServiceName string `json:"service-name"`
	// Scheme is the HTTP scheme ("http"/"https") of the auth service, used when it cannot be
	// resolved from the service registry.
	Scheme string `json:"scheme"`
	// ContextPath is the path prefix of all endpoints, used when it cannot be resolved from the
	// service registry.
	ContextPath string `json:"context-path"`
	// PasswordLogin is the path of the password login endpoint.
	PasswordLogin string `json:"password-login"`
	// SwitchContext is the path of the switch tenant/user endpoint.
	SwitchContext string `json:"switch-context"`
}
// AccountsProperties lists account credentials: one default account plus optional additional ones.
type AccountsProperties struct {
	Default AccountCredentialsProperties `json:"default"`
	Additional []AccountCredentialsProperties `json:"additional"`
}
// AccountCredentialsProperties is a username/password pair. SystemAccount marks the account as a
// system account (the built-in default account sets it to true).
type AccountCredentialsProperties struct {
	Username string `json:"username"`
	Password string `json:"password"`
	SystemAccount bool `json:"system-account"`
}
// NewSecurityIntegrationProperties creates a SecurityIntegrationProperties populated with default values:
//   - 5 minutes failure back-off and 30 seconds guaranteed validity
//   - auth service "authservice" over "http", using "/v2/token" for both password login and context switching
//   - client "nfv-service" and a default "system" account flagged as a system account
//
// NOTE(review): these defaults include hard-coded client and account credentials; deployments are
// expected to override them through configuration.
func NewSecurityIntegrationProperties() *SecurityIntegrationProperties {
	return &SecurityIntegrationProperties{
		FailureBackOff:     utils.Duration(300 * time.Second),
		GuaranteedValidity: utils.Duration(30 * time.Second),
		Endpoints: AuthEndpointsProperties{
			ServiceName:   "authservice",
			Scheme:        "http",
			PasswordLogin: "/v2/token",
			SwitchContext: "/v2/token",
		},
		Client: ClientCredentialsProperties{
			ClientId:     "nfv-service",
			ClientSecret: "nfv-service-secret",
		},
		Accounts: AccountsProperties{
			Default: AccountCredentialsProperties{
				Username:      "system",
				Password:      "system",
				SystemAccount: true,
			},
		},
	}
}
// BindSecurityIntegrationProperties creates SecurityIntegrationProperties with default values
// (see NewSecurityIntegrationProperties) and binds the application configuration found under
// PropertiesPrefix ("integrate.security") onto them.
// It panics, wrapping the bind error, if binding fails.
func BindSecurityIntegrationProperties(ctx *bootstrap.ApplicationContext) SecurityIntegrationProperties {
	props := NewSecurityIntegrationProperties()
	if err := ctx.Config().Bind(props, PropertiesPrefix); err != nil {
		panic(errors.Wrap(err, "failed to bind SecurityIntegrationProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package scope
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/integrate/httpclient"
"github.com/cisco-open/go-lanai/pkg/integrate/security/seclient"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"reflect"
"time"
)
// authenticateFunc produces the Authentication for the target (user/tenant) described by the cache key.
type authenticateFunc func(ctx context.Context, pKey *cKey) (security.Authentication, error)
// managerBase holds the state and logic shared by scope managers: the authentication cache,
// cache-expiry settings and start/end hooks.
type managerBase struct {
	cache *cache
	// tokenStoreReader validates cached access tokens and loads authentications for newly issued tokens.
	tokenStoreReader oauth2.TokenStoreReader
	// failureBackOff is how long a failed authentication stays cached before it may be retried.
	failureBackOff time.Duration
	// guaranteedValidity: cached authentications expire this long before their access token does.
	guaranteedValidity time.Duration
	beforeStartHooks []ScopeOperationHook
	afterEndHooks []ScopeOperationHook
}
// DoStartScope runs all before-start hooks and then obtains an Authentication for scope.cacheKey,
// either from the cache or by calling authFunc. It returns a child of the hook-processed context
// that carries the Authentication, the Scope, and a rollback reference to that parent context.
func (b *managerBase) DoStartScope(ctx context.Context, scope *Scope, authFunc authenticateFunc) (context.Context, error) {
	hooked := ctx
	for i := range b.beforeStartHooks {
		hooked = b.beforeStartHooks[i](hooked, scope)
	}
	auth, err := b.GetOrAuthenticate(hooked, scope.cacheKey, scope.time, authFunc)
	if err != nil {
		return nil, err
	}
	return newScopedContext(hooked, scope, auth), nil
}
// End leaves the current scope. It returns the context the scope was started from when ctx carries
// one, otherwise ctx itself, after passing it through every after-end hook. The hooks receive the
// ending Scope, or nil when ctx carries none.
func (b *managerBase) End(ctx context.Context) (ret context.Context) {
	ret = ctx
	if parent, ok := ctx.Value(ctxKeyRollback).(context.Context); ok {
		ret = parent
	}
	scope, _ := ctx.Value(ctxKeyScope).(*Scope)
	for i := range b.afterEndHooks {
		ret = b.afterEndHooks[i](ret, scope)
	}
	return ret
}
// GetOrAuthenticate returns the cached Authentication for pKey, or loads it through authFunc.
// rTime is the reference time used to compute the cache entry's expiry.
func (b *managerBase) GetOrAuthenticate(ctx context.Context, pKey *cKey, rTime time.Time, authFunc authenticateFunc) (ret security.Authentication, err error) {
	loader := b.cacheLoadFunc(rTime, authFunc)
	validator := b.cacheValidateFunc()
	return b.cache.GetOrLoad(ctx, pKey, loader, validator)
}
// resolveUser extracts the username and user ID from a fully authenticated auth.
// userId is only available when auth's details implement security.UserDetails; otherwise the
// username comes from security.GetUsername and userId is "".
func (b *managerBase) resolveUser(auth security.Authentication) (username, userId string, err error) {
	if !security.IsFullyAuthenticated(auth) {
		return "", "", fmt.Errorf("not currently authenticated")
	}
	if details, ok := auth.Details().(security.UserDetails); ok {
		return details.Username(), details.UserId(), nil
	}
	username, err = security.GetUsername(auth)
	return username, "", err
}
// normalizeTargetUser rewrites the scope's target user into canonical form when it refers to the
// currently authenticated user: the target becomes the current username and the target userId is cleared.
// Example: logged in as "user1" (userId "user1-id") with scope.userId = "user1-id"
// yields scope.userId = "" and scope.username = "user1".
// The scope is left untouched when no target user is set, the target is someone else,
// or the current user cannot be resolved.
func (b *managerBase) normalizeTargetUser(auth security.Authentication, scope *Scope) {
	if scope.username == "" && scope.userId == "" {
		return
	}
	if !b.isSameUser(scope.username, scope.userId, auth) {
		return
	}
	if currName, _, err := b.resolveUser(auth); err == nil {
		scope.username = currName
		scope.userId = ""
	}
}
// prepareCacheKey stores in scope.cacheKey a key combining the source username with the scope's
// target user and tenant identifiers.
func (b *managerBase) prepareCacheKey(scope *Scope, srcUsername string) {
	key := cKey{src: srcUsername}
	key.username, key.userId = scope.username, scope.userId
	key.tenantExternalId, key.tenantId = scope.tenantExternalId, scope.tenantId
	scope.cacheKey = &key
}
// isSameUser reports whether the non-empty username or userId matches the user of auth.
// It returns false if auth's user cannot be resolved.
func (b *managerBase) isSameUser(username, userId string, auth security.Authentication) bool {
	currName, currId, err := b.resolveUser(auth)
	switch {
	case err != nil:
		return false
	case username != "" && username == currName:
		return true
	default:
		return userId != "" && userId == currId
	}
}
// isSameTenant reports whether the target tenant matches auth's tenant. An unspecified target tenant
// (both IDs empty) always matches. Auths whose details don't implement security.TenantDetails never match.
func (b *managerBase) isSameTenant(tenantExternalId, tenantId string, auth security.Authentication) bool {
	if tenantExternalId == "" && tenantId == "" {
		return true
	}
	details, ok := auth.Details().(security.TenantDetails)
	if !ok {
		return false
	}
	if tenantId != "" && tenantId == details.TenantId() {
		return true
	}
	return tenantExternalId != "" && tenantExternalId == details.TenantExternalId()
}
// cacheValidateFunc returns a validator that treats a cached value as usable only if it carries an
// access token that the token store can still read.
func (b *managerBase) cacheValidateFunc() validateFunc {
	return func(ctx context.Context, auth entryValue) bool {
		if auth == nil {
			return false
		}
		token := auth.AccessToken()
		if token == nil {
			return false
		}
		_, err := b.tokenStoreReader.ReadAccessToken(ctx, token.Value())
		return err == nil
	}
}
// cacheLoadFunc adapts authFunc into a cache loadFunc that also computes the entry's expiry:
//   - on error: rTime + failureBackOff (or a shorter period, see calculateBackOffExp), so the failure
//     is replayed from cache until the back-off passes;
//   - on success: guaranteedValidity before the access token expires, or, if that is already before
//     rTime, whichever is earlier of the token expiry and rTime + failureBackOff.
//
// The loader runs in its own goroutine (started by the cache), where a panic would crash the process.
// An unexpected authentication type or a missing access token is therefore reported as an error
// instead of panicking.
func (b *managerBase) cacheLoadFunc(rTime time.Time, authFunc authenticateFunc) loadFunc {
	return func(ctx context.Context, k cKey) (entryValue, time.Time, error) {
		auth, e := authFunc(ctx, &k)
		// calculate exp time based on backoff time
		errExp := rTime.UTC().Add(b.failureBackOff)
		if e != nil {
			return nil, b.calculateBackOffExp(e, errExp), e
		}
		if auth == nil {
			// sanity check, this shouldn't happen
			return nil, errExp, fmt.Errorf("[Internal Error] authenticateFunc returned nil oauth without error")
		}
		// e.g. authWithoutSysAcct may hand back the caller's current authentication, which is not necessarily OAuth2
		oauth, ok := auth.(oauth2.Authentication)
		if !ok {
			return nil, errExp, fmt.Errorf("[Internal Error] authenticateFunc returned unsupported authentication type %T", auth)
		}
		token := oauth.AccessToken()
		if token == nil {
			return nil, errExp, fmt.Errorf("[Internal Error] authenticateFunc returned authentication without access token")
		}
		// try to guarantee token's validity by setting expire time a little earlier than token's exp time
		tokenExp := token.ExpiryTime().UTC()
		exp := tokenExp.Add(-1 * b.guaranteedValidity)
		if exp.Before(rTime) {
			// edge case: we cannot guarantee the token's validity, and this would persist until the token expires.
			// The token is still returned since it's at least valid now, but the entry expires at the
			// back-off time or the token expiry, whichever is earlier.
			if tokenExp.Before(errExp) {
				exp = tokenExp
			} else {
				exp = errExp
			}
		}
		return oauth, exp, nil
	}
}
// convertToAuthentication loads the Authentication associated with the access token in result from the
// token store, and rejects it unless it is authenticated, issued to a client, and, when it carries a
// user authentication, has a non-empty principal.
func (b *managerBase) convertToAuthentication(ctx context.Context, result *seclient.Result) (oauth2.Authentication, error) {
	// TODO we could leverage IDToken and probably Remote token API
	auth, e := b.tokenStoreReader.ReadAuthentication(ctx, result.Token.Value(), oauth2.TokenHintAccessToken)
	if e != nil {
		return nil, e
	}
	// perform some checks
	switch {
	case auth.State() < security.StateAuthenticated:
		return nil, fmt.Errorf("token is not associated with an authenticated session")
	case auth.OAuth2Request().ClientId() == "":
		return nil, fmt.Errorf("token is not issued to a valid client")
	case auth.UserAuthentication() != nil:
		// reflect.ValueOf(nil) yields an invalid Value whose IsZero() panics, so check nil explicitly.
		if principal := auth.UserAuthentication().Principal(); principal == nil || reflect.ValueOf(principal).IsZero() {
			return nil, fmt.Errorf("token is not authorized by a valid user")
		}
	}
	return auth, nil
}
// calculateBackOffExp returns the expiry of a cached failure. Service discovery errors are retried
// after only 10 seconds (presumably because they are transient); any other error uses defaultValue.
func (b *managerBase) calculateBackOffExp(err error, defaultValue time.Time) time.Time {
	if !errors.Is(err, httpclient.ErrorSubTypeDiscovery) {
		return defaultValue
	}
	return time.Now().UTC().Add(10 * time.Second)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package scope
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"sync"
"sync/atomic"
"time"
)
// cKey identifies a cached security context: the source user performing the switch and the
// target user/tenant. It is comparable and used directly as a map key.
type cKey struct {
	src              string // source username
	username         string // target username
	userId           string // target userId
	tenantExternalId string // target tenantExternalId
	tenantId         string // target tenantId
}

// String renders the key as "src->user@tenant". The username is preferred over the userId and
// the tenantId over the tenantExternalId.
func (k cKey) String() string {
	var user, tenant string
	switch {
	case k.username != "":
		user = k.username
	default:
		user = k.userId
	}
	switch {
	case k.tenantId != "":
		tenant = k.tenantId
	default:
		tenant = k.tenantExternalId
	}
	return fmt.Sprintf("%s->%s@%s", k.src, user, tenant)
}
// entryValue is the type of value stored in the cache: an OAuth2 authentication.
type entryValue oauth2.Authentication
// cEntry is a single cache entry whose content is loaded asynchronously.
// Once the entry's sync.WaitGroup Wait() returns, value, expire and lastErr are immutable
// and isLoaded() returns true.
type cEntry struct {
	// wg is released (Done) by the loader after value, expire and lastErr are set; readers Wait() on it.
	wg sync.WaitGroup
	// value is the loaded authentication; nil when loading failed.
	value entryValue
	// expire is when the entry stops being usable; the zero time means it never expires (see isExpired).
	expire time.Time
	// lastErr is the loader's error; it is returned to callers until the entry expires.
	lastErr error
	// invalid indicates whether "get" should still return this entry as existing.
	// Once an entry becomes "invalid", it is equivalent to "not exist".
	// It can only change from 0 to 1, atomically.
	// invalid == 1 guarantees the entry is not valid, and this status is immutable.
	// invalid == 0 does NOT guarantee the entry is valid; goroutines must also check the other fields after wg.Wait().
	invalid uint64
	// loaded lets the evicting function decide whether expire is readable without waiting on the loader.
	// Eviction runs periodically and only acts on "loaded" entries, and loaded can only change from 0 to 1,
	// so an atomic flag is sufficient and no lock is needed.
	// Other goroutines should use wg.Wait() instead.
	loaded uint64
}
// isExpired reports whether the entry has a non-zero expiry time that is now or in the past.
// It reads expire without synchronization: only call it once loading finished
// (after wg.Wait() returned, or after isLoaded() reported true).
func (e *cEntry) isExpired() bool {
	if e.expire.IsZero() {
		return false
	}
	return !time.Now().Before(e.expire)
}

// isInvalidated atomically reports whether the entry has been invalidated.
func (e *cEntry) isInvalidated() bool {
	flag := atomic.LoadUint64(&e.invalid)
	return flag != 0
}

// invalidate atomically marks the entry invalid; the mark is permanent.
func (e *cEntry) invalidate() {
	atomic.StoreUint64(&e.invalid, 1)
}

// isLoaded atomically reports whether the loader has finished populating the entry.
func (e *cEntry) isLoaded() bool {
	flag := atomic.LoadUint64(&e.loaded)
	return flag != 0
}

// markLoaded atomically flags the entry as loaded; called by the loader after writing the entry's fields.
func (e *cEntry) markLoaded() {
	atomic.StoreUint64(&e.loaded, 1)
}
// loadFunc loads the value for key k and returns it together with its expiry time and any error.
type loadFunc func(ctx context.Context, k cKey) (v entryValue, exp time.Time, err error)
// newFunc creates a new (possibly still loading) entry for the given key.
type newFunc func(context.Context, *cKey) *cEntry
// validateFunc reports whether a loaded, non-expired value is still usable.
type validateFunc func(context.Context, entryValue) bool
// cacheOptions customizes the cacheOption used by newCache.
type cacheOptions func(opt *cacheOption)
// cacheOption holds cache settings.
type cacheOption struct {
	// Heartbeat is the interval of the background eviction (newCache defaults it to 10 minutes).
	Heartbeat time.Duration
}
// cache maps cKey to asynchronously loaded entries and periodically evicts expired ones.
type cache struct {
	// mtx guards the store map itself; each entry's fields are coordinated by its WaitGroup and atomic flags.
	mtx sync.RWMutex
	store map[cKey]*cEntry
	// reaper drives the periodic eviction started by startReaper.
	reaper *time.Ticker
}
// newCache creates an empty cache and starts its background reaper. The reaper interval defaults
// to 10 minutes and can be changed through cacheOptions.
func newCache(opts ...cacheOptions) (ret *cache) {
	opt := cacheOption{Heartbeat: 10 * time.Minute}
	for i := range opts {
		opts[i](&opt)
	}
	ret = &cache{store: make(map[cKey]*cEntry)}
	ret.startReaper(opt.Heartbeat)
	return ret
}
// GetOrLoad returns the value cached for k, loading it with loader when absent.
// A freshly created entry is returned as-is, without validation. An existing entry is used only if it
// is not expired and either holds a loader error (replayed until expiry) or passes validator.
// Otherwise it is invalidated and another attempt is made, up to maxRetry retries.
func (c *cache) GetOrLoad(ctx context.Context, k *cKey, loader loadFunc, validator validateFunc) (entryValue, error) {
	const maxRetry = 2
	for attempt := 0; attempt <= maxRetry; attempt++ {
		// getOrNew guarantees only one goroutine creates a new entry for k (cache-wide RW lock)
		entry, isNew := c.getOrNew(ctx, k, c.newEntryFunc(loader))
		if entry == nil {
			return nil, fmt.Errorf("[Internal Error] security Scope cache returns nil entry")
		}
		// block until loaded; afterwards the entry's content is immutable
		entry.wg.Wait()

		usable := isNew
		if !usable && !entry.isExpired() {
			usable = entry.lastErr != nil || validator(ctx, entry.value)
		}
		if !usable {
			entry.invalidate()
			continue
		}
		if entry.lastErr != nil {
			return nil, entry.lastErr
		}
		return entry.value, nil
	}
	return nil, fmt.Errorf("unable to load valid entry")
}
// newEntryFunc returns a newFunc that creates an entry and loads it with loader in a separate goroutine.
// The returned newFunc is not goroutine-safe; callers serialize it (see newIfAbsent).
func (c *cache) newEntryFunc(loader loadFunc) newFunc {
	return func(ctx context.Context, key *cKey) *cEntry {
		entry := new(cEntry)
		// arm the WaitGroup before the loader starts so waiters block until load calls Done
		entry.wg.Add(1)
		go c.load(ctx, key, entry, loader)
		return entry
	}
}
// load runs loader for key, stores its results in entry, flags the entry loaded, and releases the
// entry's WaitGroup. It must be called exactly once per entry.
func (c *cache) load(ctx context.Context, key *cKey, entry *cEntry, loader loadFunc) {
	entry.value, entry.expire, entry.lastErr = loader(ctx, *key)
	// publish: set the atomic flag first (used by the reaper), then wake goroutines blocked in wg.Wait()
	entry.markLoaded()
	entry.wg.Done()
}
// getOrNew returns the existing, non-invalidated entry for pKey, or creates one via newIfAbsent.
// isNew reports whether the returned entry was created by this call. Goroutine-safe.
func (c *cache) getOrNew(ctx context.Context, pKey *cKey, newIfAbsent newFunc) (entry *cEntry, isNew bool) {
	// fast path under the read lock
	if existing, found := c.get(pKey); found {
		return existing, false
	}
	// slow path: write lock, re-check, create
	return c.newIfAbsent(ctx, pKey, newIfAbsent)
}
// newIfAbsent, under the write lock, returns the existing non-invalidated entry for pKey, or
// creates and stores a new one using creator. With a nil creator, nothing is created and
// whatever getValue found (possibly nil) is returned. Goroutine-safe.
func (c *cache) newIfAbsent(ctx context.Context, pKey *cKey, creator newFunc) (entry *cEntry, isNew bool) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	existing, found := c.getValue(pKey)
	if creator == nil || found && !existing.isInvalidated() {
		return existing, false
	}
	created := creator(ctx, pKey)
	c.setValue(pKey, created)
	return created, true
}
// set stores v for pKey under the write lock (a nil v deletes the key; see setValue).
// Goroutine-safe. NOTE(review): not referenced in the visible code — confirm it is still needed.
func (c *cache) set(pKey *cKey, v *cEntry) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	c.setValue(pKey, v)
}
// get returns the entry for pKey under the read lock. Invalidated entries are reported as absent.
// Goroutine-safe.
func (c *cache) get(pKey *cKey) (*cEntry, bool) {
	c.mtx.RLock()
	defer c.mtx.RUnlock()
	v, ok := c.getValue(pKey)
	if !ok || v.isInvalidated() {
		return nil, false
	}
	return v, true
}
// getValue returns the stored entry for pKey. Missing keys and nil entries are reported as absent.
// Not goroutine-safe: the caller must hold c.mtx.
func (c *cache) getValue(pKey *cKey) (*cEntry, bool) {
	v := c.store[*pKey]
	return v, v != nil
}
// setValue stores v for pKey, or deletes pKey when v is nil. Every insert also purges invalidated
// entries. Not goroutine-safe: the caller must hold the write lock.
func (c *cache) setValue(pKey *cKey, v *cEntry) {
	if v != nil {
		c.store[*pKey] = v
		c.deleteInvalidatedValues()
		return
	}
	delete(c.store, *pKey)
}
// deleteInvalidatedValues removes every entry whose invalid flag is set from the store.
// It takes no keys; the whole store is scanned.
// Not goroutine-safe: the caller must hold the write lock.
func (c *cache) deleteInvalidatedValues() {
	for k, v := range c.store {
		if v.isInvalidated() {
			// deleting the current key while ranging over a map is permitted in Go
			delete(c.store, k)
		}
	}
}
// startReaper creates a ticker with the given interval and starts a goroutine that calls evict on
// every tick. time.NewTicker panics if interval <= 0.
// NOTE(review): neither the ticker nor the goroutine is stopped anywhere in the visible code.
func (c *cache) startReaper(interval time.Duration) {
	c.reaper = time.NewTicker(interval)
	go func() {
		for {
			<-c.reaper.C
			c.evict()
		}
	}()
}
// evict removes expired entries in two phases:
//  1. under the read lock, loaded entries that have expired are flagged invalid (atomic flags only,
//     so a read lock suffices; unloaded entries are skipped because expire is not yet readable);
//  2. under the write lock, all invalidated entries are deleted from the store.
func (c *cache) evict() {
	func() {
		c.mtx.RLock()
		defer c.mtx.RUnlock()
		for _, entry := range c.store {
			if entry.isInvalidated() || !entry.isLoaded() {
				continue
			}
			if entry.isExpired() {
				entry.invalidate()
			}
		}
	}()

	c.mtx.Lock()
	defer c.mtx.Unlock()
	c.deleteInvalidatedValues()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package scope
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
// FxGroup is the fx value group name under which ManagerCustomizer providers are collected
// (see FxManagerCustomizer).
const FxGroup = "security-scope"

// scopeManager is the package-wide ScopeManager used by Scope.Do / Do. While it is nil,
// starting a scope fails with ErrNotInitialized.
var scopeManager ScopeManager
// Sentinel errors returned when starting a security scope.
var (
	// ErrNotInitialized is returned when no ScopeManager has been configured yet.
	ErrNotInitialized = fmt.Errorf("security scope manager is not initialized yet")
	// ErrMissingDefaultSysAccount is returned when a system-account scope needs the default system account but none is configured.
	ErrMissingDefaultSysAccount = fmt.Errorf("unable to switch security scope: default system account is not configured")
	// ErrMissingUser is returned when neither a username nor a user ID is available for the target user.
	ErrMissingUser = fmt.Errorf("unable to switch security scope: either username or user ID is required when not using default system account")
	// ErrNotCurrentlyAuthenticated is returned when switching without the system account from an unauthenticated context.
	ErrNotCurrentlyAuthenticated = fmt.Errorf("unable to switch security scope without system account: current context is not authenticated")
	// ErrUserIdAndUsernameExclusive is returned when both username and user ID are set.
	ErrUserIdAndUsernameExclusive = fmt.Errorf("invalid security scope option: username and user ID are exclusive")
	// ErrTenantIdAndNameExclusive is returned when both tenant external ID and tenant ID are set.
	ErrTenantIdAndNameExclusive = fmt.Errorf("invalid security scope option: tenant name and tenant ID are exclusive")
)
// Options configures a Scope; see WithUsername, WithUserId, WithTenantId, WithTenantExternalId and UseSystemAccount.
type Options func(*Scope)
// Scope describes a target security context (user and/or tenant) to switch to.
// username/userId are mutually exclusive, as are tenantExternalId/tenantId (checked by validate).
type Scope struct {
	username string // target username
	userId string // target userId
	tenantExternalId string // target tenantExternalId
	tenantId string // target tenantId
	time time.Time // creation time; the reference time for computing the cache entry's expiry
	useSysAcct bool // switch via the system account rather than the current authentication
	cacheKey *cKey // set by the manager while preparing the scope (prepareCacheKey)
}
// New creates a Scope stamped with the current time and applies opts to it in order.
func New(opts ...Options) *Scope {
	s := &Scope{time: time.Now()}
	for _, apply := range opts {
		apply(s)
	}
	return s
}
// String describes the scope as "user@tenant", or just "user" when no tenant is set.
// The username is preferred over the userId and the tenantId over the tenantExternalId.
func (s Scope) String() string {
	var user, tenant string
	if user = s.username; user == "" {
		user = s.userId
	}
	if tenant = s.tenantId; tenant == "" {
		tenant = s.tenantExternalId
	}
	if tenant == "" {
		return user
	}
	return fmt.Sprintf("%s@%s", user, tenant)
}
// Do runs fn with a context switched to this scope, then ends the scope.
// It returns the error from starting the scope, if any. A panic raised by fn (or by ending the
// scope) is recovered and returned as an error; non-error panic values are formatted with %v.
// The scope is ended even when fn panics, so after-end hooks still run.
func (s *Scope) Do(ctx context.Context, fn func(ctx context.Context)) (err error) {
	c, e := s.start(ctx)
	if e != nil {
		return e
	}
	ended := false
	defer func() {
		r := recover()
		if r != nil && !ended {
			// fn panicked before the scope was ended; end it now
			ended = true
			scopeManager.End(c)
		}
		switch v := r.(type) {
		case nil:
		case error:
			err = v
		default:
			err = fmt.Errorf("%v", v)
		}
	}()
	fn(c)
	// mark first so a panicking End (e.g. from a hook) is not invoked a second time by the deferred handler
	ended = true
	scopeManager.End(c)
	return nil
}
// start asks the package-wide ScopeManager to start this scope.
// It fails with ErrNotInitialized when no manager is configured.
func (s *Scope) start(ctx context.Context) (context.Context, error) {
	mgr := scopeManager
	if mgr == nil {
		return nil, ErrNotInitialized
	}
	return mgr.StartScope(ctx, s)
}
// validate rejects scopes that set both user identifiers or both tenant identifiers.
func (s *Scope) validate(_ context.Context) error {
	switch {
	case s.username != "" && s.userId != "":
		return ErrUserIdAndUsernameExclusive
	case s.tenantExternalId != "" && s.tenantId != "":
		return ErrTenantIdAndNameExclusive
	default:
		return nil
	}
}
// ScopeManager starts and ends security scopes.
type ScopeManager interface {
	// StartScope switches to the given scope and returns a derived context carrying the scoped authentication.
	StartScope(ctx context.Context, scope *Scope) (context.Context, error)
	// Start builds a Scope from opts and starts it (the default implementation calls StartScope(ctx, New(opts...))).
	Start(ctx context.Context, opts...Options) (context.Context, error)
	// End ends the scope carried by ctx and returns the context to continue with
	// (in the default implementation, the context the scope was started from, when available).
	End(ctx context.Context) context.Context
}
/**************************
	Convenience Functions
 **************************/
// Do invokes fn in a security scope built from opts; see Scope.Do for error and panic handling.
// e.g.:
//
//	scope.Do(ctx, func(ctx context.Context) {
//		// do something with ctx
//	}, scope.WithUsername("a-user"), scope.UseSystemAccount())
func Do(ctx context.Context, fn func(ctx context.Context), opts ...Options) error {
	s := New(opts...)
	return s.Do(ctx, fn)
}
// Describe returns a human-readable description of the scope carried by ctx, or "no scope".
func Describe(ctx context.Context) string {
	if s, ok := ctx.Value(ctxKeyScope).(*Scope); ok {
		return s.String()
	}
	return "no scope"
}
/**************************
	Hooks & Customizers
 **************************/
// ScopeOperationHook is invoked around scope operations (registered via BeforeStartHook / AfterEndHook).
// It receives the current context and the Scope, and returns the context to continue with.
// In after-end hooks, scope is nil when the ended context carried no scope.
//goland:noinspection GoNameStartsWithPackageName
type ScopeOperationHook func(ctx context.Context, scope *Scope) context.Context
// ManagerCustomizer contributes ManagerOptions for scope manager construction. Providers are
// collected through the FxGroup value group (see FxManagerCustomizer and defaultDI.Customizers).
type ManagerCustomizer interface {
	Customize() []ManagerOptions
}
// ManagerCustomizerFunc adapts a plain function to ManagerCustomizer.
type ManagerCustomizerFunc func() []ManagerOptions
// Customize returns the result of calling fn.
func (fn ManagerCustomizerFunc) Customize() []ManagerOptions {
	return fn()
}
// BeforeStartHook returns a ManagerOptions that appends hook to the hooks run before a scope starts.
func BeforeStartHook(hook ScopeOperationHook) ManagerOptions {
	return func(o *managerOption) {
		o.BeforeStartHooks = append(o.BeforeStartHooks, hook)
	}
}

// AfterEndHook returns a ManagerOptions that appends hook to the hooks run after a scope ends.
func AfterEndHook(hook ScopeOperationHook) ManagerOptions {
	return func(o *managerOption) {
		o.AfterEndHooks = append(o.AfterEndHooks, hook)
	}
}
/**************************
Context
**************************/
// Unexported context key types; distinct struct types guarantee no collision with other packages' keys.
type (
	// rollbackCtxKey locates the context a scope was started from.
	rollbackCtxKey struct{}
	// scopeCtxKey locates the active *Scope.
	scopeCtxKey struct{}
)

var (
	ctxKeyRollback = rollbackCtxKey{}
	ctxKeyScope    = scopeCtxKey{}
)
// newScopedContext creates a mutable child of parent that carries auth as its security
// Authentication. For ctxKeyRollback it resolves to parent, which lets managerBase.End return to
// the pre-scope context, and for ctxKeyScope it resolves to scope.
func newScopedContext(parent context.Context, scope *Scope, auth security.Authentication) context.Context {
	lookup := func(key interface{}) interface{} {
		switch key {
		case ctxKeyRollback:
			return parent
		case ctxKeyScope:
			return scope
		}
		return nil
	}
	scoped := utils.NewMutableContext(parent, lookup)
	security.MustSet(scoped, auth)
	return scoped
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package scope
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/integrate/security/seclient"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
// ManagerOptions customizes managerOption when constructing a scope manager.
type ManagerOptions func(opt *managerOption)
// managerOption collects the settings used by newDefaultScopeManager.
type managerOption struct {
	Client seclient.AuthenticationClient
	TokenStoreReader oauth2.TokenStoreReader
	// BackOffPeriod becomes managerBase.failureBackOff.
	BackOffPeriod time.Duration
	GuaranteedValidity time.Duration
	// KnownCredentials maps usernames to passwords; these users are authenticated only via password login.
	KnownCredentials map[string]string
	// SystemAccounts lists usernames that can only be targeted with UseSystemAccount().
	SystemAccounts utils.StringSet
	// DefaultSystemAccount is used as the target or intermediary when UseSystemAccount() is set.
	DefaultSystemAccount string
	BeforeStartHooks []ScopeOperationHook
	AfterEndHooks []ScopeOperationHook
}
// defaultScopeManager switches security context through the auth service. With UseSystemAccount() it
// first tries password login for the target, and otherwise authenticates as the default system account
// and then switches to the target. Without UseSystemAccount() it switches from the caller's current
// authentication.
type defaultScopeManager struct {
	managerBase
	client seclient.AuthenticationClient
	knownCredentials map[string]string
	systemAccounts utils.StringSet
	defaultSysAcct string
	// defaultSysAcctKey is the cache key of the default system account's own authentication.
	defaultSysAcctKey cKey
}
// newDefaultScopeManager builds a defaultScopeManager from the given options. It creates a fresh
// cache, which starts its background reaper.
func newDefaultScopeManager(opts ...ManagerOptions) *defaultScopeManager {
	var opt managerOption
	for _, apply := range opts {
		apply(&opt)
	}
	base := managerBase{
		cache:              newCache(),
		tokenStoreReader:   opt.TokenStoreReader,
		failureBackOff:     opt.BackOffPeriod,
		guaranteedValidity: opt.GuaranteedValidity,
		beforeStartHooks:   opt.BeforeStartHooks,
		afterEndHooks:      opt.AfterEndHooks,
	}
	mgr := &defaultScopeManager{
		managerBase:       base,
		client:            opt.Client,
		knownCredentials:  opt.KnownCredentials,
		systemAccounts:    opt.SystemAccounts,
		defaultSysAcct:    opt.DefaultSystemAccount,
		defaultSysAcctKey: cKey{username: opt.DefaultSystemAccount},
	}
	return mgr
}
// StartScope validates and prepares scope, then starts it. It authenticates through the system
// account when scope.useSysAcct is set, and from the current authentication otherwise.
func (m *defaultScopeManager) StartScope(ctx context.Context, scope *Scope) (context.Context, error) {
	if err := m.prepareScope(ctx, scope); err != nil {
		return nil, err
	}
	authFunc := m.authWithoutSysAcct
	if scope.useSysAcct {
		authFunc = m.authWithSysAcct
	}
	return m.managerBase.DoStartScope(ctx, scope, authFunc)
}
// Start builds a Scope from opts and starts it via StartScope.
func (m *defaultScopeManager) Start(ctx context.Context, opts ...Options) (context.Context, error) {
	return m.StartScope(ctx, New(opts...))
}
// prepareScope validates scope, then fills in its defaults and cache key according to whether the
// system account is used.
func (m *defaultScopeManager) prepareScope(ctx context.Context, scope *Scope) error {
	if err := scope.validate(ctx); err != nil {
		return err
	}
	if scope.useSysAcct {
		return m.prepareScopeWithSysAcct(ctx, scope)
	}
	return m.prepareScopeWithoutSysAcct(ctx, scope)
}
// prepareScopeWithSysAcct fills in scope defaults and the cache key for system-account mode.
// When no target user is given, the default system account becomes the target; this fails with
// ErrMissingDefaultSysAccount if none is configured. The cache key's source user is empty because
// this mode does not depend on who is currently authenticated.
func (m *defaultScopeManager) prepareScopeWithSysAcct(ctx context.Context, scope *Scope) error {
	userSpecified := scope.username != "" || scope.userId != ""
	if !userSpecified {
		if m.defaultSysAcct == "" {
			return ErrMissingDefaultSysAccount
		}
		scope.username = m.defaultSysAcct
	}
	m.normalizeTargetUser(security.Get(ctx), scope)
	m.prepareCacheKey(scope, "")
	return nil
}
// prepareScopeWithoutSysAcct fills in scope defaults and the cache key when switching from the
// current authentication. ctx must carry a fully authenticated user, otherwise
// ErrNotCurrentlyAuthenticated is returned. When no target user is specified, the current user
// becomes the target, so only the tenant is switched. The current username is recorded as the
// cache key's source.
func (m *defaultScopeManager) prepareScopeWithoutSysAcct(ctx context.Context, scope *Scope) error {
	currAuth := security.Get(ctx)
	srcUsername, _, err := m.resolveUser(currAuth)
	if err != nil {
		return ErrNotCurrentlyAuthenticated
	}
	if scope.username == "" && scope.userId == "" {
		scope.username = srcUsername
	}
	m.normalizeTargetUser(currAuth, scope)
	m.prepareCacheKey(scope, srcUsername)
	return nil
}
// authWithSysAcct is an authenticateFunc that is invoked by loadFunc in a separate goroutine,
// so it may call managerBase.GetOrAuthenticate for a *different* key without deadlocking.
// It first tries a direct password login (if the target's password is known). Otherwise it falls
// back to two-step context switching:
//  1. obtain the default system account's authentication (from cache, or via password login with
//     the system account's credentials);
//  2. call the switch user/tenant API with the system account's access token.
//
// Step 1 looks up a different cache key than pKey. If pKey *is* the default system account key
// (its password is unknown), or no default system account is configured (the key is empty), the
// nested lookup would wait on the very entry being loaded and block forever. Those cases fail fast
// with an error instead.
func (m *defaultScopeManager) authWithSysAcct(ctx context.Context, pKey *cKey) (security.Authentication, error) {
	if pKey == nil {
		return nil, fmt.Errorf("[Internal Error] cache key is nil")
	}
	// first, attempt password login
	if r, e := m.passwordLogin(ctx, pKey); e != nil {
		return nil, e
	} else if r != nil && r.Token != nil {
		return m.convertToAuthentication(ctx, r)
	}

	// then attempt to switch context using the default system account
	if m.defaultSysAcct == "" {
		return nil, ErrMissingDefaultSysAccount
	}
	if *pKey == m.defaultSysAcctKey {
		return nil, fmt.Errorf("unable to authenticate as default system account [%s]: its credentials are not configured", m.defaultSysAcct)
	}
	auth, e := m.GetOrAuthenticate(ctx, &m.defaultSysAcctKey, time.Now().UTC(), m.authWithSysAcct)
	if e != nil {
		return nil, e
	}
	r, e := m.switchContext(ctx, pKey, auth)
	if e != nil {
		return nil, e
	} else if r != nil && r.Token != nil {
		return m.convertToAuthentication(ctx, r)
	}
	// target equals the system account's own context: reuse its authentication
	return auth, nil
}
// authWithoutSysAcct is an authenticateFunc that is invoked by loadFunc in a separate goroutine.
// It switches context by calling the switch user/tenant API with the caller's current authentication.
// Targeting a configured system account is rejected, because that requires UseSystemAccount().
func (m *defaultScopeManager) authWithoutSysAcct(ctx context.Context, pKey *cKey) (security.Authentication, error) {
	switch {
	case pKey == nil:
		return nil, fmt.Errorf("[Internal Error] cache key is nil")
	case m.systemAccounts.Has(pKey.username):
		return nil, fmt.Errorf("[Internal Error] cannot switch to system account without UseSystemAccount() option")
	}

	current := security.Get(ctx)
	r, e := m.switchContext(ctx, pKey, current)
	switch {
	case e != nil:
		return nil, e
	case r != nil && r.Token != nil:
		return m.convertToAuthentication(ctx, r)
	default:
		// nothing was switched (target equals the current context, or no token returned): keep the current auth
		return current, nil
	}
}
// credentialsLookup returns the configured password for pKey's target username, if any.
// Targets identified only by userId never match.
func (m *defaultScopeManager) credentialsLookup(pKey *cKey) (password string, found bool) {
	if pKey.username != "" {
		password, found = m.knownCredentials[pKey.username]
	}
	return password, found
}
// passwordLogin logs in as pKey's target user with the configured password, selecting the target
// tenant when one is given. It returns nil, nil when no credentials are configured for that user.
func (m *defaultScopeManager) passwordLogin(ctx context.Context, pKey *cKey) (*seclient.Result, error) {
	password, found := m.credentialsLookup(pKey)
	if !found {
		return nil, nil
	}
	authOpts := make([]seclient.AuthOptions, 0, 2)
	authOpts = append(authOpts, seclient.WithCredentials(pKey.username, password))
	if hasTenant := pKey.tenantExternalId != "" || pKey.tenantId != ""; hasTenant {
		authOpts = append(authOpts, seclient.WithTenant(pKey.tenantId, pKey.tenantExternalId))
	}
	return m.client.PasswordLogin(ctx, authOpts...)
}
// switchContext switches from auth to the user/tenant described by pKey via the auth service.
// It switches tenant when pKey targets auth's user, and otherwise switches user (carrying the
// target tenant, if any). It returns nil, nil when the target user and tenant already match auth.
// Users with configured passwords are rejected: they may only be reached via password login.
func (m *defaultScopeManager) switchContext(ctx context.Context, pKey *cKey, auth security.Authentication) (*seclient.Result, error) {
	if _, passwordOnly := m.credentialsLookup(pKey); passwordOnly {
		return nil, fmt.Errorf("user [%s] is configured to use password login only", pKey.username)
	}
	authOpts := []seclient.AuthOptions{seclient.WithAuthentication(auth)}
	if pKey.tenantExternalId != "" || pKey.tenantId != "" {
		authOpts = append(authOpts, seclient.WithTenant(pKey.tenantId, pKey.tenantExternalId))
	}
	if !m.isSameUser(pKey.username, pKey.userId, auth) {
		// switch user
		authOpts = append(authOpts, seclient.WithUser(pKey.userId, pKey.username))
		return m.client.SwitchUser(ctx, authOpts...)
	}
	if m.isSameTenant(pKey.tenantExternalId, pKey.tenantId, auth) {
		return nil, nil
	}
	// switch tenant
	return m.client.SwitchTenant(ctx, authOpts...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package scope
// WithUsername targets the user with the given username and clears any user ID, since the two are exclusive.
func WithUsername(username string) Options {
	return func(s *Scope) {
		s.username, s.userId = username, ""
	}
}

// WithUserId targets the user with the given user ID and clears any username, since the two are exclusive.
func WithUserId(userId string) Options {
	return func(s *Scope) {
		s.username, s.userId = "", userId
	}
}

// WithTenantId targets the tenant with the given ID and clears any tenant external ID.
func WithTenantId(tenantId string) Options {
	return func(s *Scope) {
		s.tenantExternalId, s.tenantId = "", tenantId
	}
}

// WithTenantExternalId targets the tenant with the given external ID and clears any tenant ID.
func WithTenantExternalId(tenantExternalId string) Options {
	return func(s *Scope) {
		s.tenantExternalId, s.tenantId = tenantExternalId, ""
	}
}

// UseSystemAccount makes the scope switch through the system account instead of the current authentication.
func UseSystemAccount() Options {
	return func(s *Scope) {
		s.useSysAcct = true
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package scope
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
securityint "github.com/cisco-open/go-lanai/pkg/integrate/security"
"github.com/cisco-open/go-lanai/pkg/integrate/security/seclient"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"go.uber.org/fx"
"time"
)
var (
	logger = log.New("SEC.Scope")

	// Module provides the default ScopeManager and the tracing customizer, and installs the
	// resolved ScopeManager via configureScopeManagers.
	Module = &bootstrap.Module{
		Name:       "security scope",
		Precedence: bootstrap.SecurityIntegrationPrecedence,
		Options: []fx.Option{
			fx.Provide(provideDefaultScopeManager),
			fx.Provide(tracingProvider()),
			fx.Invoke(configureScopeManagers),
		},
	}
)
// Use enables the security scope module. It first enables seclient, which supplies the
// AuthenticationClient that the default scope manager requires.
func Use() {
	seclient.Use()
	bootstrap.Register(Module)
}
// FxManagerCustomizer annotates a ManagerCustomizer constructor so that fx adds the value it
// produces to the FxGroup value group.
func FxManagerCustomizer(constructor interface{}) fx.Annotated {
	var annotated fx.Annotated
	annotated.Group = FxGroup
	annotated.Target = constructor
	return annotated
}
// defaultDI lists the dependencies of provideDefaultScopeManager.
// AuthClient, Properties and TokenStoreReader are marked optional for fx;
// provideDefaultScopeManager returns an error when AuthClient or TokenStoreReader is missing.
type defaultDI struct {
fx.In
AuthClient seclient.AuthenticationClient `optional:"true"`
Properties securityint.SecurityIntegrationProperties `optional:"true"`
TokenStoreReader oauth2.TokenStoreReader `optional:"true"`
// Customizers are collected from the "security-scope" value group; each one contributes
// additional ManagerOptions.
Customizers []ManagerCustomizer `group:"security-scope"`
}
// provideDefaultScopeManager builds the default ScopeManager.
//
// It fails unless both the token store reader (from "resserver") and the authentication
// client (from "seclient") are available. The base option wires in those two dependencies,
// the failure back-off period, the guaranteed validity, and the configured default system
// account with its password. Options from every registered ManagerCustomizer are appended
// after the base option.
func provideDefaultScopeManager(di defaultDI) (ScopeManager, error) {
	if di.AuthClient == nil || di.TokenStoreReader == nil {
		return nil, fmt.Errorf(`security scope managers requires "resserver" and "seclient", but not configured`)
	}

	props := di.Properties
	base := func(opt *managerOption) {
		opt.Client = di.AuthClient
		opt.TokenStoreReader = di.TokenStoreReader
		opt.BackOffPeriod = time.Duration(props.FailureBackOff)
		opt.GuaranteedValidity = time.Duration(props.GuaranteedValidity)

		// Only the default account is honored. The java implementation also accepts
		// "additional" accounts (optionally flagged as system accounts) so dev-ops can give
		// them special treatment. No use case is known yet, so that support is deliberately
		// disabled; the reference code is kept below:
		//
		//for _, acct := range di.Properties.Accounts.Additional {
		//	if acct.UName == "" || acct.Password == "" {
		//		continue
		//	}
		//	credentials[acct.UName] = acct.Password
		//	if acct.SystemAccount {
		//		sysAccts.Add(acct.UName)
		//	}
		//}
		knownCreds := map[string]string{}
		systemAccts := utils.NewStringSet()
		if sysAcct := props.Accounts.Default; sysAcct.Username != "" {
			opt.DefaultSystemAccount = sysAcct.Username
			knownCreds[sysAcct.Username] = sysAcct.Password
			systemAccts.Add(sysAcct.Username)
		}
		opt.KnownCredentials = knownCreds
		opt.SystemAccounts = systemAccts
	}

	opts := make([]ManagerOptions, 0, len(di.Customizers)+1)
	opts = append(opts, base)
	for _, customizer := range di.Customizers {
		opts = append(opts, customizer.Customize()...)
	}
	return newDefaultScopeManager(opts...), nil
}
// configureScopeManagers stores the ScopeManager resolved by fx in the package-level
// scopeManager variable.
func configureScopeManagers(effective ScopeManager) {
	scopeManager = effective
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package scope
import (
"context"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"go.uber.org/fx"
)
// tracingOpName is the operation name given to spans opened by startSpanHook.
const tracingOpName = "security"
// tracingManagerCustomizer is a ManagerCustomizer that adds span start/finish hooks around
// scope operations when tracer is non-nil (see Customize).
type tracingManagerCustomizer struct {
tracer opentracing.Tracer
}
// tracingProvider returns the fx annotation that registers the tracing customizer as a
// scope manager customizer.
func tracingProvider() fx.Annotated {
	annotated := FxManagerCustomizer(newSecurityScopeManagerCustomizer)
	return annotated
}
// tracingDI holds the tracer dependency. It is optional: when no tracer is provided, Tracer
// is nil and the tracing customizer contributes no options.
type tracingDI struct {
fx.In
Tracer opentracing.Tracer `optional:"true"`
}
// newSecurityScopeManagerCustomizer creates the tracing customizer from the (possibly nil)
// injected tracer.
func newSecurityScopeManagerCustomizer(di tracingDI) ManagerCustomizer {
	customizer := &tracingManagerCustomizer{}
	customizer.tracer = di.Tracer
	return customizer
}
// Customize returns a before-start hook that opens a span and an after-end hook that
// finishes it. With no tracer configured it returns an empty, non-nil slice.
func (c *tracingManagerCustomizer) Customize() []ManagerOptions {
	if c.tracer != nil {
		return []ManagerOptions{
			BeforeStartHook(startSpanHook(c.tracer)),
			AfterEndHook(finishSpanHook(c.tracer)),
		}
	}
	return []ManagerOptions{}
}
// startSpanHook returns a hook that calls DescendantOrNoSpan with operation name
// tracingOpName and RPC-server span kind. When a scope is given, its String() form is
// attached as the "sec.scope" tag. The returned context is whatever DescendantOrNoSpan
// produces.
func startSpanHook(tracer opentracing.Tracer) ScopeOperationHook {
	return func(ctx context.Context, scope *Scope) context.Context {
		spanOpts := []tracing.SpanOption{tracing.SpanKind(ext.SpanKindRPCServerEnum)}
		if scope != nil {
			spanOpts = append(spanOpts, tracing.SpanTag("sec.scope", scope.String()))
		}
		op := tracing.WithTracer(tracer).WithOpName(tracingOpName).WithOptions(spanOpts...)
		return op.DescendantOrNoSpan(ctx)
	}
}
// finishSpanHook returns a hook that calls FinishAndRewind on ctx. Customize pairs it with
// startSpanHook. The scope argument is ignored.
func finishSpanHook(tracer opentracing.Tracer) ScopeOperationHook {
	hook := func(ctx context.Context, _ *Scope) context.Context {
		return tracing.WithTracer(tracer).FinishAndRewind(ctx)
	}
	return hook
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package seclient
import (
"context"
"github.com/cisco-open/go-lanai/pkg/integrate/httpclient"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"net/http"
"net/url"
"strings"
"time"
)
const (
// nonceCharset lists the alphanumeric characters, presumably meant for nonce generation.
// NOTE(review): generateNonce in this file calls utils.RandomString without it; confirm
// whether it is used elsewhere in the package or can be removed.
nonceCharset = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
// AuthClientOptions mutates an AuthClientOption; NewRemoteAuthClient applies them in order.
type AuthClientOptions func(opt *AuthClientOption)
// AuthClientOption configures NewRemoteAuthClient.
type AuthClientOption struct {
Client httpclient.Client // base HTTP client the auth-server client is derived from; must be non-nil
ServiceName string // service-discovery name, used only when BaseUrl is empty
Scheme string // scheme for service discovery, used only when BaseUrl is empty
ContextPath string // context path for service discovery, used only when BaseUrl is empty
BaseUrl string // when non-empty, the auth server is addressed directly and discovery is skipped
PwdLoginPath string // endpoint for password login and client credentials; NewRemoteAuthClient defaults it to "/v2/token"
SwitchContextPath string // endpoint for switch user/tenant; NewRemoteAuthClient defaults it to "/v2/token"
ClientId string // default basic-auth client ID, used when a request does not supply its own
ClientSecret string // secret paired with ClientId
}
// remoteAuthClient implements AuthenticationClient by POSTing token requests to the
// authorization server over HTTP.
type remoteAuthClient struct {
client httpclient.Client // client already bound to the auth server and configured by NewRemoteAuthClient
clientId string // default basic-auth client ID, used when a request does not override it
clientSecret string // secret paired with clientId
pwdLoginPath string // endpoint for password login and client credentials
switchCtxPath string // endpoint for switch user and switch tenant
}
// NewRemoteAuthClient creates an AuthenticationClient that talks to the authorization server
// over HTTP.
//
// Both endpoint paths default to "/v2/token" and may be overridden by opts. When BaseUrl is
// set, the server is addressed directly. Otherwise it is located through service discovery
// by ServiceName, using Scheme and ContextPath. The derived client has no before-hooks (so
// the caller's access token is not passed through), up to 2 retries, a 30-second timeout and
// debug-level logging. Errors from deriving the client are returned unchanged.
func NewRemoteAuthClient(opts ...AuthClientOptions) (AuthenticationClient, error) {
	opt := AuthClientOption{
		PwdLoginPath:      "/v2/token",
		SwitchContextPath: "/v2/token",
	}
	for _, apply := range opts {
		apply(&opt)
	}

	var (
		client httpclient.Client
		err    error
	)
	switch {
	case opt.BaseUrl != "":
		client, err = opt.Client.WithBaseUrl(opt.BaseUrl)
	default:
		client, err = opt.Client.WithService(opt.ServiceName, func(sd *httpclient.SDOption) {
			sd.Scheme = opt.Scheme
			sd.ContextPath = opt.ContextPath
		})
	}
	if err != nil {
		return nil, err
	}

	cfg := &httpclient.ClientConfig{
		// Note: we don't want access token passthrough
		BeforeHooks: []httpclient.BeforeHook{},
		Logger:      logger,
		MaxRetries:  2,
		Timeout:     30 * time.Second,
		Logging: httpclient.LoggingConfig{
			Level: log.LevelDebug,
			//DetailsLevel: httpclient.LogDetailsLevelMinimum,
			//SanitizeHeaders: utils.NewStringSet(),
		},
	}
	return &remoteAuthClient{
		client:        client.WithConfig(cfg),
		clientId:      opt.ClientId,
		clientSecret:  opt.ClientSecret,
		pwdLoginPath:  opt.PwdLoginPath,
		switchCtxPath: opt.SwitchContextPath,
	}, nil
}
// ClientCredentials requests a token with the client_credentials grant from the
// password-login endpoint. It uses basic auth with the client from opts, or the client's own
// credentials when opts names none, and forwards scopes and tenant selection when set.
func (c *remoteAuthClient) ClientCredentials(ctx context.Context, opts ...AuthOptions) (*Result, error) {
	opt := c.option(opts)
	form := WithNonEmptyURLValues(url.Values{
		oauth2.ParameterGrantType: {oauth2.GrantTypeClientCredentials},
		oauth2.ClaimScope:         {strings.Join(opt.Scopes, " ")},
	})
	reqOpts := append([]httpclient.RequestOptions{
		c.withClientAuth(opt),
		httpclient.WithHeader(httpclient.HeaderContentType, httpclient.MediaTypeFormUrlEncoded),
		httpclient.WithUrlEncodedBody(form),
	}, c.reqOptionsForTenancy(opt)...)

	req := httpclient.NewRequest(c.pwdLoginPath, http.MethodPost, reqOpts...)
	token := oauth2.NewDefaultAccessToken("")
	resp, err := c.client.Execute(ctx, req, httpclient.JsonBody(token))
	return c.handleResponse(resp, err)
}
// PasswordLogin requests a token with the password grant. The grant type and username are
// sent as query parameters. The password, a fresh 10-character nonce and the scopes go in
// the form body, with empty values omitted. Tenant selection is forwarded when set.
func (c *remoteAuthClient) PasswordLogin(ctx context.Context, opts ...AuthOptions) (*Result, error) {
	opt := c.option(opts)
	form := WithNonEmptyURLValues(url.Values{
		oauth2.ParameterPassword: {opt.Password},
		oauth2.ParameterNonce:    {c.generateNonce(10)},
		oauth2.ClaimScope:        {strings.Join(opt.Scopes, " ")},
	})
	reqOpts := append([]httpclient.RequestOptions{
		httpclient.WithParam(oauth2.ParameterGrantType, oauth2.GrantTypePassword),
		httpclient.WithParam(oauth2.ParameterUsername, opt.Username),
		c.withClientAuth(opt),
		httpclient.WithUrlEncodedBody(form),
	}, c.reqOptionsForTenancy(opt)...)

	req := httpclient.NewRequest(c.pwdLoginPath, http.MethodPost, reqOpts...)
	token := oauth2.NewDefaultAccessToken("")
	resp, err := c.client.Execute(ctx, req, httpclient.JsonBody(token))
	return c.handleResponse(resp, err)
}
// SwitchUser sends the access token from opts to the switch-context endpoint with the
// switch-user grant. The target user (by username or user ID) and the optional tenant
// selection are sent as query parameters, together with a fresh 10-character nonce.
func (c *remoteAuthClient) SwitchUser(ctx context.Context, opts ...AuthOptions) (*Result, error) {
	opt := c.option(opts)
	form := WithNonEmptyURLValues(url.Values{
		oauth2.ParameterAccessToken: {opt.AccessToken},
		oauth2.ParameterNonce:       {c.generateNonce(10)},
		oauth2.ClaimScope:           {strings.Join(opt.Scopes, " ")},
	})
	reqOpts := []httpclient.RequestOptions{
		httpclient.WithParam(oauth2.ParameterGrantType, oauth2.GrantTypeSwitchUser),
		c.withClientAuth(opt),
		httpclient.WithUrlEncodedBody(form),
	}
	reqOpts = append(append(reqOpts, c.reqOptionsForSwitchUser(opt)...), c.reqOptionsForTenancy(opt)...)

	req := httpclient.NewRequest(c.switchCtxPath, http.MethodPost, reqOpts...)
	token := oauth2.NewDefaultAccessToken("")
	resp, err := c.client.Execute(ctx, req, httpclient.JsonBody(token))
	return c.handleResponse(resp, err)
}
// SwitchTenant sends the access token from opts to the switch-context endpoint with the
// switch-tenant grant, selecting the tenant given in opts. A fresh 10-character nonce is
// included in the form body.
func (c *remoteAuthClient) SwitchTenant(ctx context.Context, opts ...AuthOptions) (*Result, error) {
	opt := c.option(opts)
	form := WithNonEmptyURLValues(url.Values{
		oauth2.ParameterAccessToken: {opt.AccessToken},
		oauth2.ParameterNonce:       {c.generateNonce(10)},
		oauth2.ClaimScope:           {strings.Join(opt.Scopes, " ")},
	})
	reqOpts := append([]httpclient.RequestOptions{
		httpclient.WithParam(oauth2.ParameterGrantType, oauth2.GrantTypeSwitchTenant),
		c.withClientAuth(opt),
		httpclient.WithUrlEncodedBody(form),
	}, c.reqOptionsForTenancy(opt)...)

	req := httpclient.NewRequest(c.switchCtxPath, http.MethodPost, reqOpts...)
	token := oauth2.NewDefaultAccessToken("")
	resp, err := c.client.Execute(ctx, req, httpclient.JsonBody(token))
	return c.handleResponse(resp, err)
}
// option applies opts, in order, to a fresh zero-valued AuthOption and returns it.
func (c *remoteAuthClient) option(opts []AuthOptions) *AuthOption {
	opt := new(AuthOption)
	for i := range opts {
		opts[i](opt)
	}
	return opt
}
// reqOptionsForTenancy returns query-parameter options for the tenant ID and tenant external
// ID, in that order, skipping empty values. The result is never nil.
func (c *remoteAuthClient) reqOptionsForTenancy(opt *AuthOption) []httpclient.RequestOptions {
	params := []struct{ key, value string }{
		{oauth2.ParameterTenantId, opt.TenantId},
		{oauth2.ParameterTenantExternalId, opt.TenantExternalId},
	}
	ret := make([]httpclient.RequestOptions, 0, len(params))
	for _, p := range params {
		if p.value != "" {
			ret = append(ret, httpclient.WithParam(p.key, p.value))
		}
	}
	return ret
}
// reqOptionsForSwitchUser returns query-parameter options naming the target user by username
// and/or user ID, in that order, skipping empty values. The result is never nil.
func (c *remoteAuthClient) reqOptionsForSwitchUser(opt *AuthOption) []httpclient.RequestOptions {
	params := []struct{ key, value string }{
		{oauth2.ParameterSwitchUsername, opt.Username},
		{oauth2.ParameterSwitchUserId, opt.UserId},
	}
	ret := make([]httpclient.RequestOptions, 0, len(params))
	for _, p := range params {
		if p.value != "" {
			ret = append(ret, httpclient.WithParam(p.key, p.value))
		}
	}
	return ret
}
// handleResponse turns the outcome of Execute into a Result. A non-nil e is returned as-is.
// Otherwise the decoded body must be an oauth2.AccessToken; the unchecked type assertion
// panics if it is not.
func (c *remoteAuthClient) handleResponse(resp *httpclient.Response, e error) (*Result, error) {
	if e == nil {
		return &Result{Token: resp.Body.(oauth2.AccessToken)}, nil
	}
	return nil, e
}
// generateNonce returns a random string of the given length from utils.RandomString.
func (c *remoteAuthClient) generateNonce(length int) string {
	nonce := utils.RandomString(length)
	return nonce
}
// withClientAuth returns an HTTP basic-auth request option. When opt carries a ClientID, that
// ID and opt.ClientSecret are used. Otherwise the client's configured clientId/clientSecret
// pair is used, and opt.ClientSecret is ignored.
func (c *remoteAuthClient) withClientAuth(opt *AuthOption) httpclient.RequestOptions {
	if opt.ClientID != "" {
		return httpclient.WithBasicAuth(opt.ClientID, opt.ClientSecret)
	}
	return httpclient.WithBasicAuth(c.clientId, c.clientSecret)
}
// WithNonEmptyURLValues converts mappedValues into url.Values, dropping empty strings.
// A key is kept only if at least one of its values is non-empty. The input is never modified,
// and the result is always a non-nil (possibly empty) url.Values.
func WithNonEmptyURLValues(mappedValues map[string][]string) url.Values {
	result := make(url.Values)
	for key, candidates := range mappedValues {
		var kept []string
		for _, v := range candidates {
			if v == "" {
				continue
			}
			kept = append(kept, v)
		}
		if len(kept) == 0 {
			continue
		}
		result[key] = kept
	}
	return result
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package seclient
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
)
// AuthOptions mutates an AuthOption; the remote client applies them in order.
type AuthOptions func(opt *AuthOption)
// AuthOption carries the per-request parameters of AuthenticationClient calls.
// The field notes describe how the remote (HTTP) implementation in this package uses them.
type AuthOption struct {
Password string // Password is used by password login
AccessToken string // AccessToken is used by switch user/tenant
Username string // Username is used by password login and switch user
UserId string // UserId is used by switch user
TenantId string // TenantId is used by password login, client credentials and switch user/tenant
TenantExternalId string // TenantExternalId is used by password login, client credentials and switch user/tenant
Scopes []string // OAuth Scopes option; sent space-separated by every request type, omitted when empty
ClientID string // ClientID overrides the configured client ID for HTTP basic auth on every request type
ClientSecret string // ClientSecret is paired with ClientID and ignored when ClientID is empty
}
// AuthenticationClient obtains access tokens from the authorization server.
type AuthenticationClient interface {
PasswordLogin(ctx context.Context, opts ...AuthOptions) (*Result, error)
ClientCredentials(ctx context.Context, opts ...AuthOptions) (*Result, error)
SwitchUser(ctx context.Context, opts ...AuthOptions) (*Result, error)
SwitchTenant(ctx context.Context, opts ...AuthOptions) (*Result, error)
}
// Result is the outcome of a successful AuthenticationClient call.
type Result struct {
Token oauth2.AccessToken
}
/****************************
AuthOptions
****************************/
// WithCredentials sets the username and password used for password login.
func WithCredentials(username, password string) AuthOptions {
	return func(opt *AuthOption) {
		opt.Password = password
		opt.Username = username
	}
}
// WithCurrentSecurity uses the access token of the authentication stored in ctx. See
// WithAuthentication for how a non-OAuth2 authentication is handled.
func WithCurrentSecurity(ctx context.Context) AuthOptions {
	auth := security.Get(ctx)
	return WithAuthentication(auth)
}
// WithAuthentication uses the access token carried by auth. If auth is not an
// oauth2.Authentication (including a nil auth), or it carries no access token, the returned
// option leaves AuthOption unchanged.
func WithAuthentication(auth security.Authentication) AuthOptions {
	oauth, ok := auth.(oauth2.Authentication)
	if !ok {
		return noop()
	}
	token := oauth.AccessToken()
	if token == nil {
		// Guard against a token-less authentication; calling Value() on it would panic.
		return noop()
	}
	return WithAccessToken(token.Value())
}
// WithAccessToken sets the access token presented by switch-user and switch-tenant requests.
func WithAccessToken(accessToken string) AuthOptions {
	setToken := func(opt *AuthOption) { opt.AccessToken = accessToken }
	return setToken
}
// WithTenant creates an option that selects the tenant by either tenantId or tenantExternalId.
// The two are mutually exclusive: tenantId takes precedence when non-empty; otherwise
// tenantExternalId is used. When both are empty, the option clears any tenant selection.
func WithTenant(tenantId string, tenantExternalId string) AuthOptions {
	if tenantId != "" {
		return WithTenantId(tenantId)
	}
	return WithTenantExternalId(tenantExternalId)
}
// WithTenantId selects the tenant by ID and clears any tenant external ID.
func WithTenantId(tenantId string) AuthOptions {
	return func(opt *AuthOption) {
		opt.TenantExternalId = ""
		opt.TenantId = tenantId
	}
}
// WithTenantExternalId selects the tenant by external ID and clears any tenant ID.
func WithTenantExternalId(tenantExternalId string) AuthOptions {
	return func(opt *AuthOption) {
		opt.TenantExternalId = tenantExternalId
		opt.TenantId = ""
	}
}
// WithUser creates an option that selects the user by either username or userId.
// The two are mutually exclusive: username wins when non-empty; otherwise userId is used.
// When both are empty, the option clears any user selection.
func WithUser(userId string, username string) AuthOptions {
	if username == "" {
		return WithUserId(userId)
	}
	return WithUsername(username)
}
// WithUsername selects the user by username and clears any user ID.
func WithUsername(username string) AuthOptions {
	return func(opt *AuthOption) {
		opt.UserId = ""
		opt.Username = username
	}
}
// WithUserId selects the user by ID and clears any username.
func WithUserId(userId string) AuthOptions {
	return func(opt *AuthOption) {
		opt.UserId = userId
		opt.Username = ""
	}
}
// WithScopes sets the OAuth scopes to request. The slice is stored as given, not copied.
func WithScopes(scopes ...string) AuthOptions {
	return func(opt *AuthOption) { opt.Scopes = scopes }
}
// WithClientAuth overrides the client ID and secret used for HTTP basic auth of a request.
// With an empty clientID, the remote client falls back to its configured credentials.
func WithClientAuth(clientID, secret string) AuthOptions {
	return func(opt *AuthOption) {
		opt.ClientSecret = secret
		opt.ClientID = clientID
	}
}
// noop returns an option that leaves the AuthOption untouched.
func noop() func(opt *AuthOption) {
	return func(*AuthOption) {}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package seclient
import (
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/integrate/httpclient"
securityint "github.com/cisco-open/go-lanai/pkg/integrate/security"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
)
var (
	logger = log.New("SEC.Client")

	// Module registers the embedded default security-integration configuration, binds
	// SecurityIntegrationProperties and provides the AuthenticationClient.
	Module = &bootstrap.Module{
		Name:       "auth-client",
		Precedence: bootstrap.SecurityIntegrationPrecedence,
		Options: []fx.Option{
			appconfig.FxEmbeddedDefaults(securityint.DefaultConfigFS),
			fx.Provide(securityint.BindSecurityIntegrationProperties),
			fx.Provide(provideAuthClient),
		},
	}
)
// Use enables the auth-client module. It first enables the httpclient module, which supplies
// the HTTP client the auth client is built on.
func Use() {
	httpclient.Use()
	bootstrap.Register(Module)
}
// clientDI lists the (required) dependencies of provideAuthClient.
type clientDI struct {
fx.In
HttpClient httpclient.Client
Properties securityint.SecurityIntegrationProperties
}
// provideAuthClient creates the remote AuthenticationClient. Client credentials come from
// Properties.Client and endpoint settings from Properties.Endpoints.
//
// NOTE(review): PasswordLogin and SwitchContext overwrite NewRemoteAuthClient's "/v2/token"
// defaults even when empty. Presumably the embedded default config always sets them; confirm.
func provideAuthClient(di clientDI) (AuthenticationClient, error) {
	creds := di.Properties.Client
	endpoints := di.Properties.Endpoints
	return NewRemoteAuthClient(func(opt *AuthClientOption) {
		opt.Client = di.HttpClient
		opt.ClientId, opt.ClientSecret = creds.ClientId, creds.ClientSecret
		opt.BaseUrl = endpoints.BaseUrl
		opt.ServiceName = endpoints.ServiceName
		opt.Scheme = endpoints.Scheme
		opt.ContextPath = endpoints.ContextPath
		opt.PwdLoginPath = endpoints.PasswordLogin
		opt.SwitchContextPath = endpoints.SwitchContext
	})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"context"
"errors"
"fmt"
"github.com/IBM/sarama"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/cisco-open/go-lanai/pkg/utils/loop"
"io"
"math"
"strings"
"sync"
"time"
)
const (
// Message templates for binding-registration and broker-connection errors.
// The *Exists templates take the topic (%s); errTmplCannotConnectBrokers takes the broker
// list and the underlying error.
errTmplProducerExists = `producer for topic %s already exist. please use the existing instance`
errTmplSubscriberExists = `subscriber for topic %s already exist. please use the existing instance`
errTmplConsumerGroupExists = `consumer group for topic %s already exist. please use the existing instance`
errTmplCannotConnectBrokers = `unable to connect to Kafka brokers %v: %v`
)
// SaramaKafkaBinder is the sarama-based Kafka binder. It owns the shared sarama client and
// cluster admin, keeps registries of producers, subscribers and group consumers keyed by
// topic, and runs a monitor loop that starts registered bindings.
type SaramaKafkaBinder struct {
sync.RWMutex
appConfig bootstrap.ApplicationConfig
properties *KafkaProperties
brokers []string
initOnce sync.Once
startOnce sync.Once
defaults bindingConfig
producerInterceptors []ProducerMessageInterceptor
consumerInterceptors []ConsumerDispatchInterceptor
handlerInterceptors []ConsumerHandlerInterceptor
monitor *loop.Loop
tlsCertsManager certs.Manager
// TODO consider mutex lock for following fields
// (Produce, Subscribe, Consume, CloseProducer and the periodic start task currently read and
// write these maps without holding the lock.)
producers map[string]BindingLifecycle
subscribers map[string]BindingLifecycle
consumerGroups map[string]BindingLifecycle
// following fields are protected by mutex lock
// NOTE(review): clusterAdminProvider replaces adminClient without taking the lock; confirm
// that is safe for its callers.
globalClient sarama.Client
adminClient sarama.ClusterAdmin
tlsSource certs.Source
provisioner *saramaTopicProvisioner
closed bool
monitorCtx context.Context
monitorCancelFunc context.CancelFunc
}
// BinderOptions mutates a BinderOption; NewBinder applies them in order.
type BinderOptions func(opt *BinderOption)
// BinderOption configures NewBinder.
type BinderOption struct {
ApplicationConfig bootstrap.ApplicationConfig // used to bind default and per-binding properties
Properties KafkaProperties // brokers, net/TLS settings, etc.; NewBinder keeps its own copy
ProducerInterceptors []ProducerMessageInterceptor // NewBinder pre-populates it with mimeTypeProducerInterceptor
ConsumerInterceptors []ConsumerDispatchInterceptor
HandlerInterceptors []ConsumerHandlerInterceptor
TLSCertsManager certs.Manager // required when Properties.Net.Tls.Enable is set
}
// NewBinder creates a SaramaKafkaBinder from opts and initializes it immediately.
//
// ProducerInterceptors starts out as [mimeTypeProducerInterceptor]; opts may replace or
// extend it. The binder keeps its own copy of opt.Properties. NewBinder panics if
// Initialize fails (for example, when the brokers are unreachable).
func NewBinder(ctx context.Context, opts ...BinderOptions) *SaramaKafkaBinder {
	opt := BinderOption{
		ProducerInterceptors: []ProducerMessageInterceptor{mimeTypeProducerInterceptor{}},
	}
	for _, apply := range opts {
		apply(&opt)
	}

	props := opt.Properties
	binder := &SaramaKafkaBinder{
		appConfig:            opt.ApplicationConfig,
		properties:           &props,
		brokers:              props.Brokers,
		tlsCertsManager:      opt.TLSCertsManager,
		producerInterceptors: opt.ProducerInterceptors,
		consumerInterceptors: opt.ConsumerInterceptors,
		handlerInterceptors:  opt.HandlerInterceptors,
		monitor:              loop.NewLoop(),
		producers:            map[string]BindingLifecycle{},
		subscribers:          map[string]BindingLifecycle{},
		consumerGroups:       map[string]BindingLifecycle{},
	}
	if err := binder.Initialize(ctx); err != nil {
		panic(err)
	}
	return binder
}
// prepareDefaults builds b.defaults, the base bindingConfig that Produce, Subscribe and
// Consume each copy.
//
// saramaDefaults is copied by value (a shallow copy). Producer defaults use a binary key
// encoder and the binder's producer interceptors. Topic provisioning has autoCreateTopic,
// autoAddPartitions and allowLowerPartitions enabled, with 1 partition and replication
// factor 1. Consumer defaults use the binder's dispatch and handler interceptors. Properties
// under ConfigKafkaDefaultBindingPrefix are then bound onto the defaults; a bind failure is
// only logged.
func (b *SaramaKafkaBinder) prepareDefaults(ctx context.Context, saramaDefaults *sarama.Config) {
	b.defaults = bindingConfig{
		name:       "default",
		properties: BindingProperties{},
		sarama:     *saramaDefaults,
		msgLogger:  newSaramaMessageLogger(),
		producer: producerConfig{
			keyEncoder:   binaryEncoder{},
			interceptors: b.producerInterceptors,
			provisioning: topicConfig{
				autoCreateTopic:      true,
				autoAddPartitions:    true,
				allowLowerPartitions: true,
				partitionCount:       1,
				replicationFactor:    1,
			},
		},
		consumer: consumerConfig{
			dispatchInterceptors: b.consumerInterceptors,
			handlerInterceptors:  b.handlerInterceptors,
			msgLogger:            newSaramaMessageLogger(),
		},
	}
	// try load default properties
	if e := b.appConfig.Bind(&b.defaults.properties, ConfigKafkaDefaultBindingPrefix); e != nil {
		// The message has a %s verb; supply the prefix so it doesn't render as %!s(MISSING).
		logger.WithContext(ctx).Infof("default kafka binding properties [%s.*] is not configured", ConfigKafkaDefaultBindingPrefix)
	}
}
// CloseProducer releases a dynamically created producer: it closes the producer registered
// for topic (logging any close error) and removes it from the registry. Unknown topics are
// ignored.
func (b *SaramaKafkaBinder) CloseProducer(ctx context.Context, topic string) {
	p, ok := b.producers[topic]
	if !ok {
		return
	}
	if e := p.Close(); e != nil {
		logger.WithContext(ctx).Errorf("error while closing kafka producer: %v", e)
	}
	delete(b.producers, topic)
}
// Produce creates and registers a Producer for topic.
//
// It fails with ErrorCodeProducerExists if an open producer is already registered for the
// topic. Configuration starts from the binder defaults; options are applied next, then the
// properties bound for the lower-cased topic name. The topic is provisioned before the
// producer is created. If the binder is already running, the producer is scheduled to start.
func (b *SaramaKafkaBinder) Produce(topic string, options ...ProducerOptions) (Producer, error) {
	if p, ok := b.producers[topic]; ok && !p.Closed() {
		logger.Warnf(errTmplProducerExists, topic)
		// NewKafkaError takes an already-formatted message (as in Initialize), so render the
		// template here instead of passing the topic as a cause.
		return nil, NewKafkaError(ErrorCodeProducerExists, fmt.Sprintf(errTmplProducerExists, topic))
	}
	// apply defaults and options
	cfg := b.defaults // make a copy
	cfg.name = strings.ToLower(topic)
	for _, optionFunc := range options {
		optionFunc(&cfg)
	}
	// load and apply properties (after options)
	props := b.loadProperties(cfg.name)
	WithProducerProperties(&props.Producer)(&cfg)
	if e := b.provisioner.provisionTopic(topic, &cfg); e != nil {
		return nil, e
	}
	p, err := newSaramaProducer(topic, b.brokers, &cfg)
	if err != nil {
		return nil, err
	}
	b.producers[topic] = p
	return p, b.tryScheduleStart(p)
}
// Subscribe creates and registers a Subscriber for topic.
//
// It fails with ErrorCodeConsumerExists if an open subscriber is already registered for the
// topic. Configuration starts from the binder defaults; options are applied next, then the
// consumer properties bound for the lower-cased topic name. If the binder is already
// running, the subscriber is scheduled to start.
func (b *SaramaKafkaBinder) Subscribe(topic string, options ...ConsumerOptions) (Subscriber, error) {
	if s, ok := b.subscribers[topic]; ok && !s.Closed() {
		logger.Warnf(errTmplSubscriberExists, topic)
		// NewKafkaError takes an already-formatted message (as in Initialize), so render the
		// template here instead of passing the topic as a cause.
		return nil, NewKafkaError(ErrorCodeConsumerExists, fmt.Sprintf(errTmplSubscriberExists, topic))
	}
	// apply defaults and options
	cfg := b.defaults // make a copy
	cfg.name = strings.ToLower(topic)
	for _, optionFunc := range options {
		optionFunc(&cfg)
	}
	// load and apply properties (after options)
	props := b.loadProperties(cfg.name)
	WithConsumerProperties(&props.Consumer)(&cfg)
	sub, err := newSaramaSubscriber(topic, b.brokers, &cfg, b.provisioner)
	if err != nil {
		return nil, err
	}
	b.subscribers[topic] = sub
	return sub, b.tryScheduleStart(sub)
}
// Consume creates and registers a GroupConsumer for topic within consumer group `group`.
//
// It fails with ErrorCodeConsumerExists if an open group consumer is already registered for
// the topic. Configuration starts from the binder defaults; options are applied next, then
// the consumer properties bound for the lower-cased topic name. If the binder is already
// running, the consumer is scheduled to start.
//
// NOTE(review): the registry is keyed by topic only, so a second group on the same topic is
// rejected; confirm that this is intended.
func (b *SaramaKafkaBinder) Consume(topic string, group string, options ...ConsumerOptions) (GroupConsumer, error) {
	if c, ok := b.consumerGroups[topic]; ok && !c.Closed() {
		logger.Warnf(errTmplConsumerGroupExists, topic)
		// NewKafkaError takes an already-formatted message (as in Initialize), so render the
		// template here instead of passing the topic as a cause.
		return nil, NewKafkaError(ErrorCodeConsumerExists, fmt.Sprintf(errTmplConsumerGroupExists, topic))
	}
	// apply defaults and options
	cfg := b.defaults // make a copy
	cfg.name = strings.ToLower(topic)
	for _, optionFunc := range options {
		optionFunc(&cfg)
	}
	// load and apply properties (after options)
	props := b.loadProperties(cfg.name)
	WithConsumerProperties(&props.Consumer)(&cfg)
	cg, err := newSaramaGroupConsumer(topic, group, b.brokers, &cfg, b.provisioner)
	if err != nil {
		return nil, err
	}
	b.consumerGroups[topic] = cg
	return cg, b.tryScheduleStart(cg)
}
// ListTopics returns the topics of all registered producers, then subscribers, then group
// consumers. Order within each registry is unspecified. A topic used by several kinds of
// binding appears once per kind.
func (b *SaramaKafkaBinder) ListTopics() (topics []string) {
	registries := []map[string]BindingLifecycle{b.producers, b.subscribers, b.consumerGroups}
	total := 0
	for _, registry := range registries {
		total += len(registry)
	}
	topics = make([]string, 0, total)
	for _, registry := range registries {
		for topic := range registry {
			topics = append(topics, topic)
		}
	}
	return topics
}
// Client returns the shared sarama client created by Initialize (nil before that).
func (b *SaramaKafkaBinder) Client() sarama.Client {
	shared := b.globalClient
	return shared
}
// Initialize implements BinderLifecycle, prepare for use, negotiate default configs, etc.
//
// The work runs at most once (initOnce): later calls return nil even if the first call
// failed. Holding the write lock, it refuses to run after Shutdown, builds the default sarama
// config from properties, and applies TLS from tlsCertsManager when TLS is enabled. It then
// prepares the binding defaults, creates the global client and the cluster admin, and wires
// the topic provisioner to the binder's provider methods.
//
// NOTE(review): if sarama.NewClusterAdmin fails, the global client created just before is
// left open, and tlsSource is not released on the later failure paths. Confirm this is
// acceptable (NewBinder panics on any error here).
func (b *SaramaKafkaBinder) Initialize(ctx context.Context) (err error) {
b.initOnce.Do(func() {
b.Lock()
defer b.Unlock()
if b.closed {
err = ErrorStartClosedBinding.WithMessage("attempt to initialize Binder after shutdown")
return
}
cfg, e := defaultSaramaConfig(ctx, b.properties)
if e != nil {
err = NewKafkaError(ErrorCodeBindingInternal, fmt.Sprintf("unable to create kafka config: %v", e))
logger.WithContext(ctx).Errorf("%v", err)
return
}
// config TLS if enabled
if b.properties.Net.Tls.Enable {
if b.tlsCertsManager == nil {
err = fmt.Errorf("failed to initialize Binder: TLS Auth is enabled but certificate manager is not provisioned")
return
}
// tlsSource is kept on the binder so Shutdown can close it if it implements io.Closer
b.tlsSource, err = b.tlsCertsManager.Source(ctx, certs.WithSourceProperties(&b.properties.Net.Tls.Certs))
if err != nil {
logger.WithContext(ctx).Errorf("failed to get tls provider: %s", err.Error())
return
}
cfg.Net.TLS.Enable = true
cfg.Net.TLS.Config, err = b.tlsSource.TLSConfig(ctx)
if err != nil {
logger.WithContext(ctx).Errorf("Failed to initialize Kafka binder: %v", err)
return
}
}
// prepare defaults
// (after TLS setup: prepareDefaults copies cfg, and that copy is what clusterAdminProvider
// later uses to recreate the admin client)
b.prepareDefaults(ctx, cfg)
// create a global client
b.globalClient, err = sarama.NewClient(b.brokers, cfg)
if err != nil {
err = NewKafkaError(ErrorCodeBrokerNotReachable, fmt.Sprintf(errTmplCannotConnectBrokers, b.brokers, err), err)
logger.WithContext(ctx).Errorf("%v", err)
return
}
b.adminClient, err = sarama.NewClusterAdmin(b.brokers, cfg)
if err != nil {
err = NewKafkaError(ErrorCodeBrokerNotReachable, fmt.Sprintf(errTmplCannotConnectBrokers, b.brokers, err), err)
logger.WithContext(ctx).Errorf("%v", err)
return
}
// the provisioner receives provider funcs rather than the clients themselves, so it
// always goes through globalClientProvider / clusterAdminProvider
b.provisioner = &saramaTopicProvisioner{
globalClient: b.globalClientProvider,
adminClient: b.clusterAdminProvider,
}
})
return
}
// Start implements BinderLifecycle. It starts all bindings that are not started yet
// (Producer, Subscriber, GroupConsumer, etc.).
//
// It runs the monitor loop with a repeating task that tries to start every registered
// binding, and shuts the binder down once the loop's context is done. Only the first call
// does anything; later calls return nil. If the binder was already shut down,
// ErrorStartClosedBinding is returned.
func (b *SaramaKafkaBinder) Start(ctx context.Context) (err error) {
	b.startOnce.Do(func() {
		b.Lock()
		defer b.Unlock()
		if b.closed {
			err = ErrorStartClosedBinding.WithMessage("attempt to start Binder after shutdown")
			return
		}
		b.monitorCtx, b.monitorCancelFunc = b.monitor.Run(ctx)
		//nolint:contextcheck // b.monitorCtx is derived from given context
		b.monitor.Repeat(b.tryStartTaskFunc(b.monitorCtx), func(opt *loop.TaskOption) {
			opt.RepeatIntervalFunc = b.tryStartRepeatIntervalFunc()
		})
		//nolint:contextcheck // b.monitorCtx is derived from given context
		go func(c context.Context) {
			<-c.Done()
			_ = b.Shutdown(ctx)
		}(b.monitorCtx)
	})
	// Previously this returned nil unconditionally, silently discarding the closed-binder error.
	return err
}
// Shutdown implements BinderLifecycle, close resources
//
// Holding the write lock, it always marks the binder closed. If the binder was started, it
// also stops the monitor loop and closes all producers, subscribers and group consumers,
// then the admin and global clients, and finally the TLS source if it implements io.Closer.
// Close errors are logged, never returned; the result is always nil.
//
// If the binder was never started, or was already shut down, only the closed flag is set.
// NOTE(review): in that case the clients created by Initialize are not closed; confirm this
// is acceptable.
func (b *SaramaKafkaBinder) Shutdown(ctx context.Context) error {
b.Lock()
defer b.Unlock()
// deferred after Unlock, so it runs first: closed is set while the lock is still held
defer func() { b.closed = true }()
// not started, or already shut down (these fields are cleared below)
if b.monitorCancelFunc == nil {
return nil
}
logger.WithContext(ctx).Infof("Kafka shutting down")
logger.WithContext(ctx).Debugf("stopping binding watchdog...")
// Cancelling also wakes the goroutine started in Start, which calls Shutdown again; that
// call waits for this lock and then returns early because monitorCancelFunc is nil.
b.monitorCancelFunc()
b.monitorCtx = nil
b.monitorCancelFunc = nil
logger.WithContext(ctx).Debugf("closing producers...")
for _, p := range b.producers {
if e := p.Close(); e != nil {
// since application is shutting down, we just log the error
logger.WithContext(ctx).Errorf("error while closing kafka producer: %v", e)
}
}
logger.WithContext(ctx).Debugf("closing subscribers...")
for _, p := range b.subscribers {
if e := p.Close(); e != nil {
// since application is shutting down, we just log the error
logger.WithContext(ctx).Errorf("error while closing kafka subscriber: %v", e)
}
}
logger.WithContext(ctx).Debugf("closing group consumers...")
for _, p := range b.consumerGroups {
if e := p.Close(); e != nil {
// since application is shutting down, we just log the error
logger.WithContext(ctx).Errorf("error while closing kafka consumer: %v", e)
}
}
logger.WithContext(ctx).Debugf("closing connections...")
if e := b.adminClient.Close(); e != nil {
logger.WithContext(ctx).Errorf("error while closing kafka admin client: %v", e)
}
if e := b.globalClient.Close(); e != nil {
logger.WithContext(ctx).Errorf("error while closing kafka global client: %v", e)
}
// tlsSource is nil when TLS is disabled; the assertion then simply fails
if closer, ok := b.tlsSource.(io.Closer); ok {
if e := closer.Close(); e != nil {
logger.WithContext(ctx).Errorf("error while closing tls config provider: %v", e)
}
}
logger.WithContext(ctx).Infof("Kafka connections closed")
return nil
}
// Done returns a channel that is closed when the binder's monitor loop ends. If the binder
// has not been started, or has already been shut down, the returned channel is already
// closed.
func (b *SaramaKafkaBinder) Done() <-chan struct{} {
	b.RLock()
	defer b.RUnlock()
	if b.monitorCtx == nil {
		closed := make(chan struct{}, 1)
		close(closed)
		return closed
	}
	return b.monitorCtx.Done()
}
// loadProperties returns the binding properties for the named binding: the default binding
// properties overlaid with whatever is bound under ConfigKafkaBindingPrefix + "." + the
// lower-cased name. If binding fails, an unmodified copy of the defaults is returned.
func (b *SaramaKafkaBinder) loadProperties(name string) *BindingProperties {
	props := b.defaults.properties // make a copy
	prefix := ConfigKafkaBindingPrefix + "." + strings.ToLower(name)
	if e := b.appConfig.Bind(&props, prefix); e == nil {
		return &props
	}
	fresh := b.defaults.properties // discard any partially bound values
	return &fresh
}
// globalClientProvider returns the shared sarama client for the topic provisioner. It never
// returns an error.
func (b *SaramaKafkaBinder) globalClientProvider() (sarama.Client, error) {
	shared := b.globalClient
	return shared, nil
}
// clusterAdminProvider returns a working cluster admin for the topic provisioner.
//
// The current admin client is probed with a ListAcls call. If the probe fails, a new admin
// client is created from the default sarama config; the old one is closed (error ignored)
// and replaced. ErrorCodeBrokerNotReachable is returned if the new client cannot be created.
//
// NOTE(review): adminClient is replaced without taking the binder lock; confirm callers
// cannot race here.
func (b *SaramaKafkaBinder) clusterAdminProvider() (sarama.ClusterAdmin, error) {
	probe := sarama.AclFilter{
		ResourceType: sarama.AclResourceTopic,
		Operation:    sarama.AclOperationRead,
	}
	if _, err := b.adminClient.ListAcls(probe); err == nil {
		return b.adminClient, nil
	}

	replacement, err := sarama.NewClusterAdmin(b.brokers, &b.defaults.sarama)
	if err != nil {
		return nil, NewKafkaError(ErrorCodeBrokerNotReachable, fmt.Sprintf(errTmplCannotConnectBrokers, b.brokers, err), err)
	}
	_ = b.adminClient.Close()
	b.adminClient = replacement
	return replacement, nil
}
// tryScheduleStart schedules a one-off start of the given BindingLifecycle on the monitor loop.
// It does nothing when there is no monitor context (monitorCtx is nil). It always returns nil.
func (b *SaramaKafkaBinder) tryScheduleStart(lc BindingLifecycle) error {
	b.RLock()
	defer b.RUnlock()
	if b.monitorCtx == nil {
		return nil
	}
	b.monitor.Do(b.tryStartSingleTaskFunc(b.monitorCtx, lc))
	return nil
}
// tryStartSingleTaskFunc builds a loop task that starts the given binding once.
// The task result is true when the start succeeded, and the start error (if any) is returned as the task error.
func (b *SaramaKafkaBinder) tryStartSingleTaskFunc(loopCtx context.Context, lc BindingLifecycle) loop.TaskFunc {
	return func(_ context.Context, _ *loop.Loop) (interface{}, error) {
		// The task's own context is cancelled as soon as this function returns, so the binding is started with
		// the longer-lived loop context instead.
		startErr := lc.Start(loopCtx)
		return startErr == nil, startErr
	}
}
// tryStartTaskFunc builds a loop task that tries to start every registered producer, subscriber and group
// consumer that is not started yet. It is meant to run periodically to perform delayed starts.
// Bindings reporting ErrorStartClosedBinding are removed from their registry. The task result is true only if
// no binding failed to start; the task itself never returns an error.
func (b *SaramaKafkaBinder) tryStartTaskFunc(loopCtx context.Context) loop.TaskFunc {
	return func(_ context.Context, _ *loop.Loop) (interface{}, error) {
		// The task's own context is cancelled as soon as this function returns, so bindings are started with
		// the longer-lived loop context instead.
		pending := false
		registries := [...]map[string]BindingLifecycle{b.producers, b.subscribers, b.consumerGroups}
		for _, registry := range registries {
			for name, lc := range registry {
				startErr := lc.Start(loopCtx)
				if errors.Is(startErr, ErrorStartClosedBinding) {
					// deleting during range is safe in Go
					delete(registry, name)
				} else if startErr != nil {
					pending = true
				}
			}
		}
		return !pending, nil
	}
}
// tryStartRepeatIntervalFunc decides how long to wait before the next run of tryStartTaskFunc.
//
// While some bindings still fail to start, the delay follows a curve that starts at InitialHeartbeat and rises
// towards WatchdogHeartbeat:
//   - an S-shaped logistic curve when the gap between the two heartbeats exceeds one minute and the curve
//     midpoint is at least 5 attempts (see logisticModel);
//   - otherwise a straight line (see linearModel).
//
// Once all bindings are started, or the task result is not a bool, WatchdogHeartbeat is used.
// The attempt counter is shared by all calls of the returned function and is not reset.
//
// Logistic Function https://en.wikipedia.org/wiki/Logistic_function
//
// https://en.wikipedia.org/wiki/Generalised_logistic_function
func (b *SaramaKafkaBinder) tryStartRepeatIntervalFunc() loop.RepeatIntervalFunc {
	lower := float64(b.properties.Binder.InitialHeartbeat)
	upper := math.Max(lower, float64(b.properties.Binder.WatchdogHeartbeat))
	midpoint := math.Max(1, b.properties.Binder.HeartbeatCurveMidpoint)
	steepness := math.Max(0.2, b.properties.Binder.HeartbeatCurveFactor)

	var delayOf func(int) time.Duration
	if upper-lower > float64(time.Minute) && midpoint >= 5 {
		delayOf = b.logisticModel(lower, upper, steepness, midpoint, time.Second)
	} else {
		delayOf = b.linearModel(lower, upper, midpoint)
	}

	attempt := -1
	return func(result interface{}, _ error) time.Duration {
		if allStarted, isBool := result.(bool); isBool && !allStarted {
			// some bindings are still pending: back off along the curve
			delay := delayOf(attempt)
			attempt++
			return delay
		}
		return time.Duration(b.properties.Binder.WatchdogHeartbeat)
	}
}
// logisticModel returns delay function f(n) using a logistic model:
//
//	f(n) = (max-min) / (1 + e^(-k*(n-n0))) + min
//
// f(n) = min for n < 0, f(n0) = (max+min)/2, and f(n) approaches max as n grows.
// k is raised when necessary so that f(0) stays within y0 of min.
//
// Logistic Function https://en.wikipedia.org/wiki/Logistic_function
//
// https://en.wikipedia.org/wiki/Generalised_logistic_function
func (b *SaramaKafkaBinder) logisticModel(min, max, k, n0 float64, y0 time.Duration) func(n int) time.Duration {
	// minK is calculated to make sure f(0) < min + y0 (the first value is within y0 of min):
	//   (max-min)/(1+e^(k*n0)) < y0  <=>  k > ln((max-min)/y0 - 1) / n0
	// The bound only exists when y0 > 0, n0 > 0 and (max-min)/y0 > 1. Outside that range math.Log or the division
	// yields NaN/±Inf, and math.Max propagates NaN, which would corrupt every returned delay.
	if ratio := (max - min) / float64(y0); y0 > 0 && n0 > 0 && ratio > 1 {
		minK := math.Log(ratio-1) / n0
		k = math.Max(k, minK)
	}
	return func(n int) time.Duration {
		if n < 0 {
			return time.Duration(min)
		}
		return time.Duration((max-min)/(1+math.Exp(-k*(float64(n)-n0))) + min)
	}
}
// linearModel returns delay function f(n) using a linear model. It is the straight-line counterpart of
// logisticModel with the same anchor points:
//   - f(0) = min;
//   - f(n0) = (max+min)/2;
//   - f(2*n0) = max.
//
// Like logisticModel, f(n) = min for n < 0, and the result never exceeds max.
// Without these bounds, the caller's first call f(-1) returned a value below min (possibly negative), and the
// delay kept growing past max on long retry streaks.
func (b *SaramaKafkaBinder) linearModel(min, max, n0 float64) func(n int) time.Duration {
	return func(n int) time.Duration {
		if n < 0 {
			return time.Duration(min)
		}
		return time.Duration(math.Min(max, (max-min)/n0/2*float64(n)+min))
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"context"
"errors"
"fmt"
"github.com/IBM/sarama"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"sync"
)
// saramaGroupConsumer is a consumer-group based binding for a single topic.
// Start, Close and Closed hold the embedded mutex while reading or changing lifecycle state.
type saramaGroupConsumer struct {
sync.Mutex
// topic consumed by this binding
topic string
// group is the consumer group ID passed to sarama.NewConsumerGroup
group string
// brokers are the broker addresses passed to sarama.NewConsumerGroup
brokers []string
// config carries the sarama config and consumer interceptors for this binding
config *bindingConfig
// dispatcher dispatches received messages to registered handlers
dispatcher *saramaDispatcher
// provisioner is used by Start to check that the topic exists
provisioner *saramaTopicProvisioner
// started is set after a successful Start and reset by Close
started bool
// consumer is created by Start; nil before that
consumer sarama.ConsumerGroup
// cancelFunc cancels the context of goroutines launched by Start
cancelFunc context.CancelFunc
// closed is set by Close; a closed consumer cannot be started again
closed bool
}
// newSaramaGroupConsumer creates a (not yet started) group consumer for the given topic and consumer group.
// It returns an ErrorSubTypeBindingInternal error when group is empty.
// Note: the consumer interceptors in config are sorted in place.
func newSaramaGroupConsumer(topic string, group string, addrs []string, config *bindingConfig, provisioner *saramaTopicProvisioner) (*saramaGroupConsumer, error) {
	if len(group) == 0 {
		return nil, ErrorSubTypeBindingInternal.WithMessage("group is required and cannot be empty")
	}

	order.SortStable(config.consumer.dispatchInterceptors, order.OrderedFirstCompare)
	order.SortStable(config.consumer.handlerInterceptors, order.OrderedFirstCompare)

	consumer := &saramaGroupConsumer{
		topic:       topic,
		group:       group,
		brokers:     addrs,
		config:      config,
		dispatcher:  newSaramaDispatcher(config),
		provisioner: provisioner,
	}
	return consumer, nil
}
// Topic returns the topic consumed by this binding.
func (g *saramaGroupConsumer) Topic() string { return g.topic }

// Group returns the consumer group ID used when joining the group.
func (g *saramaGroupConsumer) Group() string { return g.group }
// Start creates the underlying sarama consumer group and begins consuming the topic in background goroutines.
// One goroutine runs the group session loop. A second one logs group errors, only when the sarama config has
// Consumer.Return.Errors enabled. Both are bound to a context derived from ctx, which Close cancels.
//
// Start is a no-op if the consumer is already started. It returns ErrorStartClosedBinding if the consumer was
// closed. It fails with ErrorCodeIllegalState if the topic doesn't exist or its existence cannot be verified.
func (g *saramaGroupConsumer) Start(ctx context.Context) (err error) {
	g.Lock()
	defer g.Unlock()
	// mark as started only on success; this runs before the deferred Unlock (defers are LIFO)
	defer func() {
		if err == nil {
			g.started = true
		}
	}()

	switch {
	case g.closed:
		return ErrorStartClosedBinding.WithMessage("cannot re-start a closed consumer [%s]", g.topic)
	case g.started:
		return nil
	}

	exists, lookupErr := g.provisioner.topicExists(g.topic)
	switch {
	case lookupErr != nil:
		// surface the lookup failure instead of misreporting it as a missing topic
		return NewKafkaError(ErrorCodeIllegalState, fmt.Sprintf(`unable to verify that topic "%s" exists: %v`, g.topic, lookupErr), lookupErr)
	case !exists:
		return NewKafkaError(ErrorCodeIllegalState, fmt.Sprintf(`topic "%s" does not exists`, g.topic))
	}

	var e error
	g.consumer, e = sarama.NewConsumerGroup(g.brokers, g.group, &g.config.sarama)
	if e != nil {
		return translateSaramaBindingError(e, "%s", e.Error())
	}

	cancelCtx, cancelFunc := context.WithCancel(ctx)
	if g.config.sarama.Consumer.Return.Errors {
		go g.monitorGroupErrors(cancelCtx, g.consumer)
	}
	go g.handleGroup(cancelCtx, g.consumer)
	g.cancelFunc = cancelFunc
	return nil
}
// Close stops the consumer. It cancels the context of the background goroutines started by Start and closes the
// underlying sarama consumer group, if one was created. The consumer is marked closed (and not started) even if
// closing fails, and Start refuses to run on a closed consumer.
func (g *saramaGroupConsumer) Close() error {
	g.Lock()
	defer g.Unlock()
	defer func() {
		g.started = false
		g.closed = true
	}()

	if g.cancelFunc != nil {
		g.cancelFunc()
		g.cancelFunc = nil
	}
	if g.consumer == nil {
		return nil
	}
	if e := g.consumer.Close(); e != nil {
		// NewKafkaError uses its text verbatim (it is not a format string), so the message is formatted here
		return NewKafkaError(ErrorCodeIllegalState, fmt.Sprintf("error when closing group consumer: %v", e), e)
	}
	return nil
}
// Closed reports whether Close has been called on this consumer.
func (g *saramaGroupConsumer) Closed() (closed bool) {
	g.Lock()
	closed = g.closed
	g.Unlock()
	return closed
}
// AddHandler registers a message handler with this consumer's dispatcher. The binding's consumer configuration
// (its handler interceptors) is applied in addition to the given options.
func (g *saramaGroupConsumer) AddHandler(handlerFunc MessageHandlerFunc, opts ...DispatchOptions) error {
	consumerCfg := &g.config.consumer
	return g.dispatcher.AddHandler(handlerFunc, consumerCfg, opts)
}
// monitorGroupErrors logs errors reported by the consumer group. It stops when any of these happen:
//   - the error channel is closed;
//   - sarama.ErrClosedConsumerGroup is reported;
//   - ctx is done.
//
// It blocks, so it should be run in a separate goroutine.
func (g *saramaGroupConsumer) monitorGroupErrors(ctx context.Context, group sarama.ConsumerGroup) {
	for {
		select {
		case <-ctx.Done():
			return
		case e, ok := <-group.Errors():
			if !ok || errors.Is(e, sarama.ErrClosedConsumerGroup) {
				return
			}
			logger.WithContext(ctx).Warnf("Consumer Group Error: %v", e)
		}
	}
}
// handleGroup runs the consumer-group session loop for this consumer's topic. It blocks, so it should be run in
// a separate goroutine. It returns when the group is closed (sarama.ErrClosedConsumerGroup) or ctx is cancelled.
func (g *saramaGroupConsumer) handleGroup(ctx context.Context, group sarama.ConsumerGroup) {
	gh := saramaGroupHandler{
		owner:      g,
		dispatcher: g.dispatcher,
	}
	for {
		// `Consume` should be called inside an infinite loop, when a server-side re-balance happens, the consumer session will need to be recreated to get the new claims
		if e := group.Consume(ctx, []string{g.topic}, gh); e != nil {
			if errors.Is(e, sarama.ErrClosedConsumerGroup) {
				return
			}
			logger.WithContext(ctx).Warnf("Consumer Error: %v", e)
		}
		// Consume returns when the session ends. If ctx was cancelled (e.g. by Close), stop here instead of
		// re-joining the group over and over with an already-cancelled context.
		if ctx.Err() != nil {
			return
		}
	}
}
// saramaGroupHandler implements sarama.ConsumerGroupHandler.
// It forwards every claimed message to the dispatcher of the owning group consumer.
type saramaGroupHandler struct {
// owner is the group consumer this handler belongs to; it is passed as the message source when dispatching
owner *saramaGroupConsumer
dispatcher *saramaDispatcher
}
// Setup implements sarama.ConsumerGroupHandler. It only logs the partitions claimed by this member.
func (h saramaGroupHandler) Setup(session sarama.ConsumerGroupSession) error {
	for topic, partitions := range session.Claims() {
		logger.WithContext(session.Context()).Debugf(
			"Consumer joined group [%s] Topic=[%s] Partitions=%v MemberID=[%s]",
			h.owner.group, topic, partitions, session.MemberID(),
		)
	}
	return nil
}

// Cleanup implements sarama.ConsumerGroupHandler. It only logs the partitions this member is releasing.
func (h saramaGroupHandler) Cleanup(session sarama.ConsumerGroupSession) error {
	for topic, partitions := range session.Claims() {
		logger.WithContext(session.Context()).Debugf(
			"Consumer left group [%s] Topic=[%s] Partitions=%v MemberID=[%s]",
			h.owner.group, topic, partitions, session.MemberID(),
		)
	}
	return nil
}
// ConsumeClaim implements sarama.ConsumerGroupHandler and is run in a separate goroutine.
// Each message of the claim is handed to handleMessage on its own goroutine. The loop ends when the claim's
// message channel is closed or the session context is done.
// NOTE(review): because every message gets its own goroutine, messages of one partition may be handled
// concurrently and complete out of order; confirm this is intended given per-message MarkMessage/ResetOffset.
func (h saramaGroupHandler) ConsumeClaim(session sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error {
	for {
		select {
		case <-session.Context().Done():
			return nil
		case msg, ok := <-claim.Messages():
			if !ok {
				return nil
			}
			go h.handleMessage(session.Context(), session, msg)
		}
	}
}
// handleMessage dispatches one raw message and is intended to run in a separate goroutine.
// On success the message is marked as consumed. On failure the partition offset is reset back to this message,
// with the error text as offset metadata.
func (h saramaGroupHandler) handleMessage(ctx context.Context, session sarama.ConsumerGroupSession, raw *sarama.ConsumerMessage) {
	dispatchErr := h.dispatcher.Dispatch(ctx, raw, h.owner)
	if dispatchErr == nil {
		session.MarkMessage(raw, "")
		return
	}
	logger.WithContext(ctx).Warnf("failed to handle message: %v", dispatchErr)
	// TODO we should consider limit retry count, or let Handler decide whether to retry by specifying a special error type
	session.ResetOffset(raw.Topic, raw.Partition, raw.Offset, dispatchErr.Error())
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"context"
"encoding/json"
"errors"
"github.com/IBM/sarama"
"reflect"
"strings"
)
// MessageHandlerFunc is a message handling function that conforms to the following signature:
//
// func (ctx context.Context, [OPTIONAL_INPUT_PARAMS...]) error
//
// Where OPTIONAL_INPUT_PARAMS could contain following components (of which order is not important):
// - PAYLOAD_PARAM < AnyPayloadType >: message payload, where PayloadType could be any type other than interface, function or chan.
// If PayloadType is interface{}, raw []byte will be used
// - HEADERS_PARAM < Headers >: message headers
// - METADATA_PARAM < *MessageMetadata >: message metadata, includes timestamp, keys, partition, etc.
// - MESSAGE_PARAM < *Message >: raw message, where Message.Payload would be PayloadType if PAYLOAD_PARAM is also present, or []byte
//
// For Example:
//
// func Handle(ctx context.Context, payload *MyStruct) error
// func Handle(ctx context.Context, payload *MyStruct, meta *MessageMetadata) error
// func Handle(ctx context.Context, payload map[string]interface{}) error
// func Handle(ctx context.Context, headers Headers, payload *MyStruct) error
// func Handle(ctx context.Context, payload *MyStruct, raw *Message) error
// func Handle(ctx context.Context, raw *Message) error
type MessageHandlerFunc interface{}
// MessageFilterFunc decides whether a handler receives the message; returning false skips that handler for this
// message (see Dispatcher.Dispatch).
type MessageFilterFunc func(ctx context.Context, msg *Message) (shouldHandle bool)
// Reflect types computed once and used by Dispatcher.AddHandler to classify a handler's input and output params.
var (
reflectTypeContext = reflect.TypeOf((*context.Context)(nil)).Elem()
reflectTypeHeaders = reflect.TypeOf(Headers{})
reflectTypeMetadata = reflect.TypeOf(&MessageMetadata{})
reflectTypeMessage = reflect.TypeOf(&Message{})
reflectTypeError = reflect.TypeOf((*error)(nil)).Elem()
)
// param records where a handler input parameter sits in the handler's argument list (i) and its declared type (t).
// A zero param (t == nil) means the handler does not declare that parameter.
type param struct {
i int
t reflect.Type
}
// assign converts v to the parameter's declared type and stores it at the parameter's position in params.
// It is a no-op when the handler does not declare the parameter (t == nil) or its index is out of range.
// It returns an ErrorSubTypeIllegalConsumerUsage error when v is invalid (e.g. a nil payload) or cannot be
// converted to the declared type.
func (p param) assign(params []reflect.Value, v reflect.Value) error {
	if p.i >= len(params) || p.t == nil {
		return nil
	}
	if !v.IsValid() {
		// reflect.ValueOf(nil); calling v.Type() on it would panic
		return ErrorSubTypeIllegalConsumerUsage.WithMessage("failed to prepare parameters for message handler: cannot assign nil to %v", p.t)
	}
	if !v.Type().ConvertibleTo(p.t) {
		return ErrorSubTypeIllegalConsumerUsage.WithMessage("failed to prepare parameters for message handler: cannot assign %v to %v", v.Type(), p.t)
	}
	params[p.i] = v.Convert(p.t)
	return nil
}
// params describes the parsed input parameters of a MessageHandlerFunc.
type params struct {
// count is the total number of input parameters of the handler
count int
payload param
headers param
metadata param
message param
}
// handler is a registered MessageHandlerFunc together with its parsed parameters and dispatch options.
type handler struct {
fn reflect.Value
params params
// filterFunc, if set, decides whether this handler receives a message
filterFunc MessageFilterFunc
// interceptors are invoked around this handler (BeforeHandling / AfterHandling)
interceptors []ConsumerHandlerInterceptor
}
/**************************
Dispatcher
**************************/
// Dispatcher processes a MessageContext and dispatches it to registered MessageHandlerFunc(s).
// This struct is intended for Subscriber or GroupConsumer implementors. It should not be directly used by application.
type Dispatcher struct {
// handlers registered via AddHandler, invoked in registration order
handlers []*handler
// Interceptors are applied to every MessageContext before handlers run; those implementing
// ConsumerDispatchFinalizer are also given the final result
Interceptors []ConsumerDispatchInterceptor
// Logger, if set, logs every received raw message
Logger MessageLogger
}
// AddHandler validates fn against the MessageHandlerFunc contract and registers it.
// DispatchOptions (e.g. filters and handler interceptors) are applied to the new handler.
//
// A nil fn, including a typed nil function, is ignored. The following signatures would only fail (via
// recovered panics) at dispatch time, so they are rejected here with an ErrorSubTypeIllegalConsumerUsage error:
//   - non-function values;
//   - functions without params or without results;
//   - duplicated Headers / *MessageMetadata / *Message params.
func (d *Dispatcher) AddHandler(fn MessageHandlerFunc, opts ...DispatchOptions) error {
	if fn == nil {
		return nil
	}
	f := reflect.ValueOf(fn)
	if f.Kind() != reflect.Func {
		return ErrorSubTypeIllegalConsumerUsage.WithMessage("invalid MessageHandlerFunc %v, expected a function but got %T", fn, fn)
	}
	if f.IsNil() {
		return nil
	}

	// apply options
	h := handler{
		fn: f,
	}
	for _, optFn := range opts {
		optFn(&h)
	}

	// parse and validate input params (from last to first)
	t := f.Type()
	if t.NumIn() == 0 {
		return ErrorSubTypeIllegalConsumerUsage.WithMessage("invalid MessageHandlerFunc signature %v, first input param must be context.Context", fn)
	}
	for i := t.NumIn() - 1; i >= 0; i-- {
		switch it := t.In(i); {
		case it.AssignableTo(reflectTypeContext):
			if i != 0 {
				return ErrorSubTypeIllegalConsumerUsage.WithMessage("invalid MessageHandlerFunc signature %v, first input param must be context.Context", fn)
			}
		case it.ConvertibleTo(reflectTypeHeaders):
			if h.params.headers.t != nil {
				return ErrorSubTypeIllegalConsumerUsage.WithMessage("invalid MessageHandlerFunc signature %v, duplicate input parameters at index %v", fn, i)
			}
			h.params.headers = param{i, it}
		case it.ConvertibleTo(reflectTypeMetadata):
			if h.params.metadata.t != nil {
				return ErrorSubTypeIllegalConsumerUsage.WithMessage("invalid MessageHandlerFunc signature %v, duplicate input parameters at index %v", fn, i)
			}
			h.params.metadata = param{i, it}
		case it.ConvertibleTo(reflectTypeMessage):
			if h.params.message.t != nil {
				return ErrorSubTypeIllegalConsumerUsage.WithMessage("invalid MessageHandlerFunc signature %v, duplicate input parameters at index %v", fn, i)
			}
			h.params.message = param{i, it}
		case h.params.payload.t == nil && d.isSupportedMessagePayloadType(it):
			h.params.payload = param{i, it}
		default:
			return ErrorSubTypeIllegalConsumerUsage.WithMessage("invalid MessageHandlerFunc signature %v, unknown input parameters at index %v", fn, i)
		}
		h.params.count++
	}

	// parse and validate output params: the handler must return exactly one error
	if t.NumOut() == 0 {
		return ErrorSubTypeIllegalConsumerUsage.WithMessage("invalid MessageHandlerFunc signature %v, last output param must be error", fn)
	}
	for i := t.NumOut() - 1; i >= 0; i-- {
		switch ot := t.Out(i); {
		case ot.ConvertibleTo(reflectTypeError):
			if i != t.NumOut()-1 {
				return ErrorSubTypeIllegalConsumerUsage.WithMessage("invalid MessageHandlerFunc signature %v, last output param must be error", fn)
			}
		default:
			return ErrorSubTypeIllegalConsumerUsage.WithMessage("invalid MessageHandlerFunc signature %v, unknown output parameters at index %v", fn, i)
		}
	}
	d.handlers = append(d.handlers, &h)
	return nil
}
// Dispatch runs a message through the dispatcher pipeline:
//  1. every ConsumerDispatchInterceptor, in order; an interceptor error aborts dispatching,
//  2. the optional MessageLogger,
//  3. every registered handler whose filter (if any) accepts the message, stopping at the first handler error,
//  4. finalizeDispatch, which lets ConsumerDispatchFinalizer interceptors post-process the result and wraps
//     non-kafka errors.
//
// Any panic raised while dispatching is recovered and returned as an ErrorSubTypeConsumerGeneral error.
//
//nolint:contextcheck // context is passed inside msgCtx
func (d *Dispatcher) Dispatch(msgCtx *MessageContext) (err error) {
	defer func() {
		switch e := recover().(type) {
		case nil:
			// no panic
		case error:
			err = ErrorSubTypeConsumerGeneral.WithCause(e, "message dispatcher recovered from panic: %v", e)
		case string:
			err = ErrorSubTypeConsumerGeneral.WithMessage("message dispatcher recovered from panic: %v", e)
		default:
			// other panic values must fail the dispatch too; otherwise the message would be reported as handled
			err = ErrorSubTypeConsumerGeneral.WithMessage("message dispatcher recovered from panic: %v", e)
		}
	}()

	// invoke Interceptors
	for _, interceptor := range d.Interceptors {
		msgCtx, err = interceptor.Intercept(msgCtx)
		if err != nil {
			// keep the interceptor error as the cause so it can still be matched with errors.Is / errors.As
			return ErrorSubTypeConsumerGeneral.WithCause(err, "consumer dispatch interceptor error: %v", err)
		}
	}
	defer func() {
		err = d.finalizeDispatch(msgCtx, err)
	}()

	// log message
	if d.Logger != nil {
		d.Logger.LogReceivedMessage(msgCtx.Context, msgCtx.RawMessage)
	}

	for _, h := range d.handlers {
		// apply filters
		if h.filterFunc != nil && !h.filterFunc(msgCtx.Context, &msgCtx.Message) {
			continue
		}
		if err = d.dispatch(msgCtx, h); err != nil {
			return
		}
	}
	return nil
}
// dispatch invokes a single handler for the message, in this order:
//  1. each handler interceptor's BeforeHandling;
//  2. payload decoding into the handler's payload type;
//  3. the handler itself;
//  4. each interceptor's AfterHandling, deferred, which may replace the result.
//
// AfterHandling is not run when a BeforeHandling fails. The handler works on a shallow copy of msgCtx.Message, so
// decoding the payload does not change msgCtx.Message.
func (d *Dispatcher) dispatch(msgCtx *MessageContext, h *handler) (err error) {
	ctx, msg := msgCtx.Context, msgCtx.Message
	for _, interceptor := range h.interceptors {
		ctx, err = interceptor.BeforeHandling(ctx, &msg)
		if err != nil {
			// keep the interceptor error as the cause so it can still be matched with errors.Is / errors.As
			return ErrorSubTypeConsumerGeneral.WithCause(err, "consumer handler interceptor error: %v", err)
		}
	}
	defer func() {
		for _, interceptor := range h.interceptors {
			ctx, err = interceptor.AfterHandling(ctx, &msg, err)
		}
	}()

	// decode payload
	if err = d.decodePayload(ctx, h.params.payload.t, &msg); err != nil {
		return
	}
	err = d.invokeHandler(ctx, h, &msg, msgCtx)
	return
}
// finalizeDispatch passes the dispatch result through every interceptor that implements
// ConsumerDispatchFinalizer, in order. A nil result stays nil. Errors already in ErrorCategoryKafka are returned
// as-is, and anything else is wrapped as a consumer-general kafka error.
func (d *Dispatcher) finalizeDispatch(msgCtx *MessageContext, err error) error {
	for _, interceptor := range d.Interceptors {
		if finalizer, ok := interceptor.(ConsumerDispatchFinalizer); ok {
			msgCtx, err = finalizer.Finalize(msgCtx, err)
		}
	}
	switch {
	case err == nil:
		return nil
	case errors.Is(err, ErrorCategoryKafka):
		return err
	default:
		return NewKafkaError(ErrorSubTypeCodeConsumerGeneral, err.Error(), err)
	}
}
/********************
Helpers
********************/
// decodePayload converts a raw []byte payload according to the message's Content-Type header:
//   - "application/json*": unmarshal into a new value of typ;
//   - MIMETypeText: convert to string;
//   - MIMETypeBinary: keep the bytes.
//
// It does nothing when typ is nil or the payload is not []byte. Any other content type results in an
// ErrorSubTypeDecoding error.
func (d *Dispatcher) decodePayload(_ context.Context, typ reflect.Type, msg *Message) error {
	raw, isBytes := msg.Payload.([]byte)
	if !isBytes || typ == nil {
		return nil
	}
	switch contentType := msg.Headers[HeaderContentType]; {
	case strings.HasPrefix(contentType, "application/json"):
		ptr, val := d.instantiateByType(typ)
		if err := json.Unmarshal(raw, ptr.Interface()); err != nil {
			return ErrorSubTypeDecoding.WithCause(err, "unable to decode as JSON: %v", err)
		}
		msg.Payload = val.Interface()
	case contentType == MIMETypeText:
		msg.Payload = string(raw)
	case contentType == MIMETypeBinary:
		// keep the raw bytes
	default:
		return ErrorSubTypeDecoding.WithMessage("unsupported MIME type %s", contentType)
	}
	return nil
}
// invokeHandler builds the argument list for handler.fn from the message and calls it.
//
// ctx is placed into slot 0 by default; a handler param declared at index 0 overwrites it. Payload, headers,
// *Message and *MessageMetadata are then filled in at their declared positions. Metadata is derived from the raw
// *sarama.ConsumerMessage when available and is empty otherwise. The handler's error result, if any, is returned.
func (d *Dispatcher) invokeHandler(ctx context.Context, handler *handler, msg *Message, msgCtx *MessageContext) (err error) {
	// prepare input params; a handler without params gets an empty argument list
	in := make([]reflect.Value, handler.params.count)
	if len(in) > 0 {
		in[0] = reflect.ValueOf(ctx)
	}
	if e := handler.params.payload.assign(in, reflect.ValueOf(msg.Payload)); e != nil {
		return e
	}
	if e := handler.params.headers.assign(in, reflect.ValueOf(msg.Headers)); e != nil {
		return e
	}
	if e := handler.params.message.assign(in, reflect.ValueOf(msg)); e != nil {
		return e
	}

	// message metadata: presence is signalled by a non-nil type, because index 0 is a valid position
	if handler.params.metadata.t != nil {
		var meta *MessageMetadata
		switch raw := msgCtx.RawMessage.(type) {
		case *sarama.ConsumerMessage:
			meta = &MessageMetadata{
				Key:       raw.Key,
				Partition: int(raw.Partition),
				Offset:    int(raw.Offset),
				Timestamp: raw.Timestamp,
			}
		default:
			meta = &MessageMetadata{}
		}
		if e := handler.params.metadata.assign(in, reflect.ValueOf(meta)); e != nil {
			return e
		}
	}

	// invoke
	out := handler.fn.Call(in)

	// post process output; a handler without results cannot report an error
	if len(out) == 0 {
		return nil
	}
	err, _ = out[0].Interface().(error)
	return
}
// instantiateByType allocates a new zero value for type t.
// "ptr" is a pointer to the innermost non-pointer value, regardless of whether t is a pointer type, so it can be
// passed to decoders such as json.Unmarshal.
// "value" has exactly type t; for pointer types it points at the same allocation as "ptr".
func (d *Dispatcher) instantiateByType(t reflect.Type) (ptr reflect.Value, value reflect.Value) {
	if t.Kind() != reflect.Ptr {
		fresh := reflect.New(t)
		return fresh, fresh.Elem()
	}
	holder := reflect.New(t)
	innerPtr, innerVal := d.instantiateByType(t.Elem())
	holder.Elem().Set(innerVal.Addr())
	return innerPtr, holder.Elem()
}
// isSupportedMessagePayloadType reports whether t can be used as a handler payload type.
// Interfaces, functions and channels are rejected. A single level of pointer is allowed, but pointers to
// pointers are not.
func (d *Dispatcher) isSupportedMessagePayloadType(t reflect.Type) bool {
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
		if t.Kind() == reflect.Ptr {
			return false
		}
	}
	switch t.Kind() {
	case reflect.Interface, reflect.Func, reflect.Chan:
		return false
	}
	return true
}
/**************************
sarama dispatcher
**************************/
// saramaDispatcher adapts Dispatcher to raw *sarama.ConsumerMessage input.
type saramaDispatcher struct {
Dispatcher
}
// newSaramaDispatcher creates a dispatcher with no handlers. It uses the binding's dispatch interceptors and
// message logger.
func newSaramaDispatcher(cfg *bindingConfig) *saramaDispatcher {
	base := Dispatcher{
		handlers:     make([]*handler, 0),
		Interceptors: cfg.consumer.dispatchInterceptors,
		Logger:       cfg.msgLogger,
	}
	return &saramaDispatcher{Dispatcher: base}
}
// Dispatch converts a raw sarama message into a MessageContext and dispatches it.
// Message headers become string key/value pairs; nil headers and headers with an empty key or value are dropped.
// The raw value is used as the (not yet decoded) payload.
func (d *saramaDispatcher) Dispatch(ctx context.Context, raw *sarama.ConsumerMessage, source interface{}) (err error) {
	headers := make(Headers, len(raw.Headers))
	for _, rh := range raw.Headers {
		if rh != nil && len(rh.Key) > 0 && len(rh.Value) > 0 {
			headers[string(rh.Key)] = string(rh.Value)
		}
	}
	return d.Dispatcher.Dispatch(&MessageContext{
		Context: ctx,
		Message: Message{
			Headers: headers,
			Payload: raw.Value,
		},
		Source:     source,
		Topic:      raw.Topic,
		RawMessage: raw,
	})
}
// AddHandler registers fn with the consumer's handler interceptors applied before the caller-supplied options.
func (d *saramaDispatcher) AddHandler(fn MessageHandlerFunc, cfg *consumerConfig, opts []DispatchOptions) error {
	combined := make([]DispatchOptions, 0, len(opts)+1)
	combined = append(combined, AddInterceptors(cfg.handlerInterceptors...))
	combined = append(combined, opts...)
	return d.Dispatcher.AddHandler(fn, combined...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"encoding"
"encoding/json"
"github.com/IBM/sarama"
)
// jsonEncoder encodes values with encoding/json.
type jsonEncoder struct{}

// MIMEType returns MIMETypeJson.
func (enc jsonEncoder) MIMEType() string {
	return MIMETypeJson
}

// Encode marshals v as JSON. Marshalling failures are wrapped as ErrorSubTypeEncoding errors.
func (enc jsonEncoder) Encode(v interface{}) ([]byte, error) {
	data, err := json.Marshal(v)
	if err != nil {
		return data, ErrorSubTypeEncoding.WithCause(err, "%s", err.Error())
	}
	return data, nil
}
// binaryEncoder encodes string, []byte and encoding.BinaryMarshaler values as raw bytes.
type binaryEncoder struct{}

// MIMEType returns MIMETypeBinary.
func (enc binaryEncoder) MIMEType() string {
	return MIMETypeBinary
}

// Encode returns the bytes of v. Strings are converted, byte slices are returned as-is and BinaryMarshalers are
// marshalled. Any other type results in an ErrorSubTypeEncoding error.
func (enc binaryEncoder) Encode(v interface{}) ([]byte, error) {
	if s, ok := v.(string); ok {
		return []byte(s), nil
	}
	if b, ok := v.([]byte); ok {
		return b, nil
	}
	marshaler, ok := v.(encoding.BinaryMarshaler)
	if !ok {
		return nil, ErrorSubTypeEncoding.WithMessage("unsupported value for binary encoding: %T", v)
	}
	data, err := marshaler.MarshalBinary()
	if err != nil {
		return data, ErrorSubTypeEncoding.WithCause(err, "%s", err.Error())
	}
	return data, nil
}
// saramaEncoderWrapper adapts a value plus an Encoder to sarama.Encoder, encoding the value on first use.
type saramaEncoderWrapper struct {
// v is the value to encode
v interface{}
// enc performs the encoding
enc Encoder
// cache holds the encoded bytes returned by Encode on later calls; nil means not encoded yet
cache []byte
}
// newSaramaEncoder wraps v as a sarama.Encoder that encodes it with enc, defaulting to binaryEncoder.
// A nil v yields a nil sarama.Encoder.
func newSaramaEncoder(v interface{}, enc Encoder) sarama.Encoder {
	switch {
	case v == nil:
		return nil
	case enc == nil:
		enc = binaryEncoder{}
	}
	return &saramaEncoderWrapper{v: v, enc: enc}
}
// Encode implements sarama.Encoder. A successful result is cached and returned by later calls; Length also goes
// through Encode.
// Failed attempts are never cached. Previously, output returned alongside an error was cached whenever it was
// non-nil, so the next call silently returned partial bytes with a nil error.
func (w *saramaEncoderWrapper) Encode() ([]byte, error) {
	if w.cache != nil {
		return w.cache, nil
	}
	encoded, err := w.enc.Encode(w.v)
	if err != nil {
		return encoded, err
	}
	w.cache = encoded
	return encoded, nil
}
// Length implements sarama.Encoder. It returns the length of the encoded value, or 0 if encoding fails.
func (w *saramaEncoderWrapper) Length() int {
	if data, err := w.Encode(); err == nil {
		return len(data)
	}
	return 0
}
// mimeTypeProducerInterceptor implements ProducerMessageInterceptor. It performs three steps:
//   - defaults the value encoder to jsonEncoder;
//   - records the encoder's MIME type in the Content-Type header;
//   - wraps the payload in a sarama.Encoder that encodes it lazily.
type mimeTypeProducerInterceptor struct{}

func (i mimeTypeProducerInterceptor) Intercept(msgCtx *MessageContext) (*MessageContext, error) {
	msg := &msgCtx.Message
	if msg.Headers == nil {
		msg.Headers = Headers{}
	}
	if msgCtx.ValueEncoder == nil {
		msgCtx.ValueEncoder = jsonEncoder{}
	}
	valueEncoder := msgCtx.ValueEncoder
	msg.Headers[HeaderContentType] = valueEncoder.MIMEType()
	msg.Payload = newSaramaEncoder(msg.Payload, valueEncoder)
	return msgCtx, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"errors"
"fmt"
"github.com/IBM/sarama"
. "github.com/cisco-open/go-lanai/pkg/utils/error"
errorutils "github.com/cisco-open/go-lanai/pkg/utils/error"
)
const (
// Reserved kafka reserved error range: category code 0x1a shifted by errorutils.ReservedOffset
Reserved = 0x1a << errorutils.ReservedOffset
)
// All "Type" values are used as mask
// In each group below, iota 0 is skipped by "_", so the first code is parent + 1 (shifted where applicable).
const (
_ = iota
ErrorTypeCodeBinding = Reserved + iota<<errorutils.ErrorTypeOffset
ErrorTypeCodeProducer
ErrorTypeCodeConsumer
)
// All "SubType" values are used as mask
// sub-types of ErrorTypeCodeBinding
const (
_ = iota
ErrorSubTypeCodeBindingInternal = ErrorTypeCodeBinding + iota<<errorutils.ErrorSubTypeOffset
ErrorSubTypeCodeConnectivity
ErrorSubTypeCodeProvisioning
)
// ErrorSubTypeCodeBindingInternal
const (
_ = iota
ErrorCodeBindingInternal = ErrorSubTypeCodeBindingInternal + iota
)
// ErrorSubTypeCodeConnectivity
const (
_ = iota
ErrorCodeBrokerNotReachable = ErrorSubTypeCodeConnectivity + iota
)
// ErrorSubTypeCodeProvisioning
const (
_ = iota
ErrorCodeIllegalState = ErrorSubTypeCodeProvisioning + iota
ErrorCodeProducerExists
ErrorCodeConsumerExists
ErrorCodeAutoCreateTopicFailed
ErrorCodeAutoAddPartitionsFailed
ErrorCodeIllegalLifecycleState
)
// All "SubType" values are used as mask
// sub-types of ErrorTypeCodeProducer
const (
_ = iota
ErrorSubTypeCodeProducerGeneral = ErrorTypeCodeProducer + iota<<errorutils.ErrorSubTypeOffset
ErrorSubTypeCodeIllegalProducerUsage
ErrorSubTypeCodeEncoding
)
// All "SubType" values are used as mask
// sub-types of ErrorTypeCodeConsumer
const (
_ = iota
ErrorSubTypeCodeConsumerGeneral = ErrorTypeCodeConsumer + iota<<errorutils.ErrorSubTypeOffset
ErrorSubTypeCodeIllegalConsumerUsage
ErrorSubTypeCodeDecoding
)
// ErrorTypes, can be used in errors.Is
// ErrorCategoryKafka matches every kafka error; the Type/SubType values match errors of their code range.
// ErrorStartClosedBinding is returned when starting a binding that was already closed.
//
//goland:noinspection GoUnusedGlobalVariable
var (
ErrorCategoryKafka = NewErrorCategory(Reserved, errors.New("error type: kafka"))
ErrorTypeBinding = NewErrorType(ErrorTypeCodeBinding, errors.New("error type: binding"))
ErrorTypeProducer = NewErrorType(ErrorTypeCodeProducer, errors.New("error type: producer"))
ErrorTypeConsumer = NewErrorType(ErrorTypeCodeConsumer, errors.New("error type: consumer"))
ErrorSubTypeBindingInternal = NewErrorSubType(ErrorSubTypeCodeBindingInternal, errors.New("error sub-type: internal"))
ErrorSubTypeConnectivity = NewErrorSubType(ErrorSubTypeCodeConnectivity, errors.New("error sub-type: connectivity"))
ErrorSubTypeProvisioning = NewErrorSubType(ErrorSubTypeCodeProvisioning, errors.New("error sub-type: provisioning"))
ErrorSubTypeProducerGeneral = NewErrorSubType(ErrorSubTypeCodeProducerGeneral, errors.New("error sub-type: producer"))
ErrorSubTypeIllegalProducerUsage = NewErrorSubType(ErrorSubTypeCodeIllegalProducerUsage, errors.New("error sub-type: producer api usage"))
ErrorSubTypeEncoding = NewErrorSubType(ErrorSubTypeCodeEncoding, errors.New("error sub-type: encoding"))
ErrorSubTypeConsumerGeneral = NewErrorSubType(ErrorSubTypeCodeConsumerGeneral, errors.New("error sub-type: consumer"))
ErrorSubTypeIllegalConsumerUsage = NewErrorSubType(ErrorSubTypeCodeIllegalConsumerUsage, errors.New("error sub-type: consumer api usage"))
ErrorSubTypeDecoding = NewErrorSubType(ErrorSubTypeCodeDecoding, errors.New("error sub-type: decoding"))
ErrorStartClosedBinding = NewKafkaError(ErrorCodeIllegalLifecycleState, "error: cannot start closed binding")
)
func init() {
errorutils.Reserve(ErrorCategoryKafka)
}
/************************
Constructors
*************************/
// NewKafkaError creates a *CodedError with the given code and message.
// text is used verbatim as the error message (it is NOT a format string); causes are forwarded to NewCodedError.
func NewKafkaError(code int64, text string, causes ...interface{}) *CodedError {
	msgErr := errors.New(text)
	return NewCodedError(code, msgErr, causes...)
}
// translateSaramaBindingError converts an error returned by sarama while binding into a kafka CodedError.
// Errors already in ErrorCategoryKafka are returned unchanged. msg and args build the error text, fmt.Sprintf style.
// Sentinel errors are matched with errors.Is rather than ==, because sarama can return them wrapped with extra
// context (e.g. ErrOutOfBrokers together with per-broker errors). A plain comparison would miss those and
// misclassify them as internal errors.
func translateSaramaBindingError(cause error, msg string, args ...interface{}) error {
	if errors.Is(cause, ErrorCategoryKafka) {
		return cause
	}
	switch {
	case errors.Is(cause, sarama.ErrOutOfBrokers):
		return NewKafkaError(ErrorCodeBrokerNotReachable, fmt.Sprintf(msg, args...), cause)
	case errors.Is(cause, sarama.ErrClosedClient), errors.Is(cause, sarama.ErrAlreadyConnected),
		errors.Is(cause, sarama.ErrNotConnected), errors.Is(cause, sarama.ErrShuttingDown),
		errors.Is(cause, sarama.ErrControllerNotAvailable):
		return NewKafkaError(ErrorCodeIllegalState, fmt.Sprintf(msg, args...), cause)
	case errors.Is(cause, sarama.ErrInvalidPartition), errors.Is(cause, sarama.ErrIncompleteResponse),
		errors.Is(cause, sarama.ErrInsufficientData), errors.Is(cause, sarama.ErrMessageTooLarge),
		errors.Is(cause, sarama.ErrNoTopicsToUpdateMetadata):
		return ErrorSubTypeProvisioning.WithCause(cause, msg, args...)
	case errors.Is(cause, sarama.ErrConsumerOffsetNotAdvanced):
		// note, this should not happen during binding, we use generic internal
		fallthrough
	default:
		return NewKafkaError(ErrorCodeBindingInternal, fmt.Sprintf(msg, args...), cause)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
)
// HealthIndicator reports Kafka health by refreshing metadata through the binder's client.
type HealthIndicator struct {
	binder SaramaBinder
}

// NewHealthIndicator creates a HealthIndicator for the given Binder.
// The binder must also implement SaramaBinder; otherwise the type assertion panics.
func NewHealthIndicator(binder Binder) *HealthIndicator {
	saramaBinder := binder.(SaramaBinder)
	return &HealthIndicator{binder: saramaBinder}
}

// Name returns the indicator name, "kafka".
func (i *HealthIndicator) Name() string {
	const name = "kafka"
	return name
}
// Health refreshes metadata for all topics known to the binder and reports the result:
//   - StatusUnknown when the binder has no client yet;
//   - StatusDown when the metadata refresh fails;
//   - StatusUp otherwise.
//
// With opts.ShowDetails, the topic list is included in the details.
func (i *HealthIndicator) Health(_ context.Context, opts health.Options) health.Health {
	topics := i.binder.ListTopics()
	client := i.binder.Client()
	if client == nil {
		return health.NewDetailedHealth(health.StatusUnknown, "kafka client not initialized yet", nil)
	}

	var details map[string]interface{}
	if opts.ShowDetails {
		details = map[string]interface{}{"topics": topics}
	}

	status, desc := health.StatusUp, "kafka refresh metadata succeeded"
	if refreshErr := client.RefreshMetadata(topics...); refreshErr != nil {
		status, desc = health.StatusDown, "kafka refresh metadata failed"
	}
	return health.NewDetailedHealth(status, desc, details)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"context"
"fmt"
"github.com/IBM/sarama"
"github.com/cisco-open/go-lanai/pkg/log"
)
// MessageLogger logs summaries of Kafka messages that are sent or received.
type MessageLogger interface {
// WithLevel returns a MessageLogger that logs at the given level
WithLevel(level log.LoggingLevel) MessageLogger
// LogSentMessage logs a sent message; msg is implementation-specific (e.g. *sarama.ProducerMessage)
LogSentMessage(ctx context.Context, msg interface{})
// LogReceivedMessage logs a received message; msg is implementation-specific (e.g. *sarama.ConsumerMessage)
LogReceivedMessage(ctx context.Context, msg interface{})
}
// LoggerOptions customizes loggerOption when creating a message logger.
type LoggerOptions func(opt *loggerOption)
type loggerOption struct {
// Name of the underlying logger ("Kafka.Msg" by default in newSaramaMessageLogger)
Name string
// Level at which messages are logged (log.LevelDebug by default in newSaramaMessageLogger)
Level log.LoggingLevel
}
// saramaMessageLogger is the MessageLogger implementation for sarama producer/consumer messages.
type saramaMessageLogger struct {
logger log.ContextualLogger
level log.LoggingLevel
}
// newSaramaMessageLogger creates a message logger named "Kafka.Msg" at debug level, after applying opts.
func newSaramaMessageLogger(opts ...LoggerOptions) *saramaMessageLogger {
	opt := loggerOption{Name: "Kafka.Msg", Level: log.LevelDebug}
	for i := range opts {
		opts[i](&opt)
	}
	return &saramaMessageLogger{logger: log.New(opt.Name), level: opt.Level}
}
// WithLevel returns a copy of this logger that logs at the given level and shares the same underlying logger.
func (l saramaMessageLogger) WithLevel(level log.LoggingLevel) MessageLogger {
	clone := l
	clone.level = level
	return &clone
}
func (l saramaMessageLogger) LogSentMessage(ctx context.Context, msg interface{}) {
switch m := msg.(type) {
case *sarama.ProducerMessage:
logMsg := fmt.Sprintf("[SENT] [%s] Partition[%d] Offset[%d]: Length=%dB",
m.Topic, m.Partition, m.Offset, m.Value.Length())
if m.Key != nil && m.Key.Length() != 0 {
logMsg = logMsg + fmt.Sprintf(" KeyLength=%dB", m.Key.Length())
}
logger.WithContext(ctx).WithLevel(l.level).Printf(logMsg)
}
}
// LogReceivedMessage logs a one-line summary of a received *sarama.ConsumerMessage: topic, partition,
// offset, value length and, when the key is non-empty, the key in hex.
// The summary is logged at this logger's level. Messages of any other type are ignored.
func (l saramaMessageLogger) LogReceivedMessage(ctx context.Context, msg interface{}) {
	m, ok := msg.(*sarama.ConsumerMessage)
	if !ok || m == nil {
		return
	}
	logMsg := fmt.Sprintf("[RECV] [%s] Partition[%d] Offset[%d]: Length=%dB",
		m.Topic, m.Partition, m.Offset, len(m.Value))
	if len(m.Key) != 0 {
		logMsg += fmt.Sprintf(" Key=%x", m.Key)
	}
	// Use the logger configured for this MessageLogger (not the package logger), and a constant
	// format string so that '%' characters in topic names are not interpreted as verbs.
	l.logger.WithContext(ctx).WithLevel(l.level).Printf("%s", logMsg)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"context"
"github.com/IBM/sarama"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
"time"
)
// defaultSaramaConfig builds the baseline sarama configuration from KafkaProperties.
// It sets the protocol version (2.0.0.0), the client ID and the metadata refresh frequency.
// SASL settings are copied only when SASL is enabled. The returned error is currently always nil.
func defaultSaramaConfig(_ context.Context, properties *KafkaProperties) (c *sarama.Config, err error) {
	cfg := sarama.NewConfig()
	cfg.Version = sarama.V2_0_0_0
	cfg.ClientID = properties.ClientId
	cfg.Metadata.RefreshFrequency = time.Duration(properties.Metadata.RefreshFrequency)

	if sasl := properties.Net.Sasl; sasl.Enable {
		cfg.Net.SASL.Enable = sasl.Enable
		cfg.Net.SASL.Handshake = sasl.Handshake
		cfg.Net.SASL.User = sasl.User
		cfg.Net.SASL.Password = sasl.Password
	}
	return cfg, nil
}
/*************************************
Options for Producer and Consumer
**************************************/
// Note: the return types here have to be unnamed func types for the compiler to accept them as both ProducerOptions and ConsumerOptions
// See https://golang.org/ref/spec#Type_identity
// BindingName is a ProducerOptions or ConsumerOptions that specifies the name of the binding.
// The name is used to read BindingProperties from bootstrap.ApplicationConfig.
// If not specified, the lower-cased topic name is used. An empty name leaves the current name unchanged.
// Regardless of whether a name is specified or matching BindingProperties are found,
// any ProducerOptions or ConsumerOptions given in code still apply.
// The overriding order is as follows:
//
//	BindingProperties with matching name >
//	BindingProperties with name "default" >
//	ProducerOptions or ConsumerOptions >
//	prepared defaults during initialization
func BindingName(name string) func(cfg *bindingConfig) {
	return func(config *bindingConfig) {
		if name == "" {
			return
		}
		config.name = name
	}
}
// LogLevel is a ProducerOptions or ConsumerOptions that sets the level at which the binding's
// message logger logs messages of the Producer, Subscriber or Consumer.
func LogLevel(level log.LoggingLevel) func(cfg *bindingConfig) {
	return func(config *bindingConfig) {
		leveled := config.msgLogger.WithLevel(level)
		config.msgLogger = leveled
	}
}
/***********************
Options for producer
************************/
// WithProducerProperties returns a ProducerOptions that applies the settings of ProducerProperties.
// Only properties that are set (non-nil) override the current configuration:
//   - AckMode selects RequireAllAck, RequireLocalAck or RequireNoAck
//   - LogLevel is applied via LogLevel
//   - AckTimeout, MaxRetry and Backoff map to sarama's Producer.Timeout, Producer.Retry.Max and Producer.Retry.Backoff
//   - Provisioning settings map to the producer's topic provisioning config
func WithProducerProperties(p *ProducerProperties) ProducerOptions {
	return func(cfg *bindingConfig) {
		if mode := p.AckMode; mode != nil {
			var ackOpt ProducerOptions
			switch *mode {
			case AckModeModeAll:
				ackOpt = RequireAllAck()
			case AckModeModeLocal:
				ackOpt = RequireLocalAck()
			case AckModeModeNone:
				ackOpt = RequireNoAck()
			}
			if ackOpt != nil {
				ackOpt(cfg)
			}
		}
		if lvl := p.LogLevel; lvl != nil {
			LogLevel(*lvl)(cfg)
		}

		producer := &cfg.sarama.Producer
		utils.MustSetIfNotNil(&producer.Timeout, p.AckTimeout)
		utils.MustSetIfNotNil(&producer.Retry.Max, p.MaxRetry)
		utils.MustSetIfNotNil(&producer.Retry.Backoff, p.Backoff)

		provisioning := &cfg.producer.provisioning
		utils.MustSetIfNotNil(&provisioning.autoCreateTopic, p.Provisioning.AutoCreateTopic)
		utils.MustSetIfNotNil(&provisioning.autoAddPartitions, p.Provisioning.AutoAddPartitions)
		utils.MustSetIfNotNil(&provisioning.allowLowerPartitions, p.Provisioning.AllowLowerPartitions)
		utils.MustSetIfNotNil(&provisioning.partitionCount, p.Provisioning.PartitionCount)
		utils.MustSetIfNotNil(&provisioning.replicationFactor, p.Provisioning.ReplicationFactor)
	}
}
// KeyEncoder returns a ProducerOptions that sets the Encoder used to serialize message keys.
func KeyEncoder(enc Encoder) ProducerOptions {
	return func(cfg *bindingConfig) {
		cfg.producer.keyEncoder = enc
	}
}
// Partitions configures the Producer's topic provisioning. It specifies the minimum number of partitions
// required and the replication factor (number of replicas per partition) to use in case topics are auto-created.
//
// Values less than 1 are raised to 1. Values that exceed the range of the underlying Kafka types
// (int32 partition count, int16 replication factor) are capped at the maximum instead of silently overflowing.
func Partitions(partitionCount int, replicationFactor int) ProducerOptions {
	const (
		maxPartitionCount    = 1<<31 - 1 // math.MaxInt32
		maxReplicationFactor = 1<<15 - 1 // math.MaxInt16
	)
	count := partitionCount
	switch {
	case count < 1:
		count = 1
	case count > maxPartitionCount:
		count = maxPartitionCount
	}
	replicas := replicationFactor
	switch {
	case replicas < 1:
		replicas = 1
	case replicas > maxReplicationFactor:
		replicas = maxReplicationFactor
	}
	return func(config *bindingConfig) {
		config.producer.provisioning.partitionCount = int32(count)
		config.producer.provisioning.replicationFactor = int16(replicas)
	}
}
// RequireAllAck returns a ProducerOptions that waits for all in-sync replicas to commit before responding.
// The minimum number of in-sync replicas is configured on the broker via
// the `min.insync.replicas` configuration Key.
func RequireAllAck() ProducerOptions {
	return func(cfg *bindingConfig) {
		cfg.sarama.Producer.RequiredAcks = sarama.WaitForAll
	}
}

// RequireLocalAck returns a ProducerOptions that waits only for the local commit to succeed before responding.
func RequireLocalAck() ProducerOptions {
	return func(cfg *bindingConfig) {
		cfg.sarama.Producer.RequiredAcks = sarama.WaitForLocal
	}
}

// RequireNoAck returns a ProducerOptions under which the broker sends no response; the TCP ACK is all you get.
func RequireNoAck() ProducerOptions {
	return func(cfg *bindingConfig) {
		cfg.sarama.Producer.RequiredAcks = sarama.NoResponse
	}
}

// AckTimeout returns a ProducerOptions that sets sarama's Producer.Timeout, i.e. how long to wait
// for the required acknowledgements.
func AckTimeout(timeout time.Duration) ProducerOptions {
	return func(cfg *bindingConfig) {
		cfg.sarama.Producer.Timeout = timeout
	}
}
/***********************
Options for consumer
************************/
// WithConsumerProperties returns a ConsumerOptions that applies the settings of ConsumerProperties.
// Only properties that are set (non-nil) override the current configuration:
//   - LogLevel is applied via LogLevel
//   - Backoff maps to sarama's Consumer.Retry.Backoff
//   - Group.JoinTimeout, Group.MaxRetry and Group.Backoff map to sarama's Consumer.Group.Rebalance settings
func WithConsumerProperties(p *ConsumerProperties) ConsumerOptions {
	return func(cfg *bindingConfig) {
		if lvl := p.LogLevel; lvl != nil {
			LogLevel(*lvl)(cfg)
		}
		consumer := &cfg.sarama.Consumer
		utils.MustSetIfNotNil(&consumer.Retry.Backoff, p.Backoff)

		rebalance := &consumer.Group.Rebalance
		utils.MustSetIfNotNil(&rebalance.Timeout, p.Group.JoinTimeout)
		utils.MustSetIfNotNil(&rebalance.Retry.Max, p.Group.MaxRetry)
		utils.MustSetIfNotNil(&rebalance.Retry.Backoff, p.Group.Backoff)
	}
}
/**********************
Options for message
***********************/
// deliveryMode selects how a message is delivered by the producer.
type deliveryMode int

const (
	// modeSync sends the message with the SyncProducer and waits for the result.
	modeSync deliveryMode = iota
)

// messageConfig holds per-message settings; it is customized via MessageOptions.
type messageConfig struct {
	// ValueEncoder encodes the message payload
	ValueEncoder Encoder
	// Key is the message key, encoded with the producer's KeyEncoder
	Key interface{}
	// Mode is the delivery mode
	Mode deliveryMode
}
// defaultMessageConfig returns the per-message defaults: jsonEncoder for the value, no key,
// and synchronous delivery.
func defaultMessageConfig() messageConfig {
	var cfg messageConfig
	cfg.ValueEncoder = jsonEncoder{}
	cfg.Mode = modeSync
	return cfg
}
// MessageOptions customizes the messageConfig of a single message being sent.
type MessageOptions func(config *messageConfig)

// WithKey specifies the key used for the message. The key is typically used for partitioning.
// Supported values depend on the KeyEncoder option of the Producer.
// The default encoder supports the following types:
//   - uuid.UUID
//   - string
//   - []byte
//   - encoding.BinaryMarshaler
func WithKey(key interface{}) MessageOptions {
	return func(cfg *messageConfig) {
		cfg.Key = key
	}
}
// WithEncoder specifies how the message payload is encoded.
// Defaults to JSON encoding (jsonEncoder, see defaultMessageConfig).
func WithEncoder(valueEncoder Encoder) MessageOptions {
	return func(cfg *messageConfig) {
		cfg.ValueEncoder = valueEncoder
	}
}
/*************************
Options for dispatcher
**************************/
// DispatchOptions customizes a message handler when it is registered with a subscriber or consumer.
type DispatchOptions func(h *handler)

// FilterOnHeader returns a DispatchOptions specifying that the handler should only be invoked
// when the given message header exists and its value matches the provided matcher.
// Matcher errors are treated as "no match". A nil matcher yields a no-op option (no filter is installed).
func FilterOnHeader(header string, matcher matcher.StringMatcher) DispatchOptions {
	if matcher == nil {
		return noop()
	}
	filter := func(ctx context.Context, msg *Message) bool {
		// lookups on a nil Headers map simply report "not found"
		value, found := msg.Headers[header]
		if !found {
			return false
		}
		matched, e := matcher.MatchesWithContext(ctx, value)
		return e == nil && matched
	}
	return func(h *handler) {
		h.filterFunc = filter
	}
}
// AddInterceptors returns a DispatchOptions that appends the given ConsumerHandlerInterceptor(s)
// to those of the handler.
func AddInterceptors(interceptors ...ConsumerHandlerInterceptor) DispatchOptions {
	return func(target *handler) {
		target.interceptors = append(target.interceptors, interceptors...)
	}
}
// noop returns a dispatch option that leaves the handler untouched.
func noop() func(h *handler) {
	return func(*handler) {}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
"reflect"
)
// logger is the package-level logger, named "Kafka".
var logger = log.New("Kafka")

// Module is the bootstrap module of the Kafka binder. It provides KafkaProperties and the Binder,
// provides the tracing integration, and invokes initialize to hook the Binder into the fx lifecycle.
var Module = &bootstrap.Module{
	Precedence: bootstrap.KafkaPrecedence,
	Options: []fx.Option{
		fx.Provide(BindKafkaProperties, ProvideKafkaBinder),
		fx.Provide(tracingProvider()),
		fx.Invoke(initialize),
	},
}

const (
	// FxGroup is the fx value-group name from which interceptors are collected (see binderDI).
	FxGroup = "kafka"
)
// Use registers the Kafka Module with bootstrap, so that services can include Kafka support.
// It is typically called from main().
func Use() {
	bootstrap.Register(Module)
}
// binderDI declares the fx dependencies of ProvideKafkaBinder.
// Interceptors are collected from the "kafka" value group (FxGroup); TLSCertsManager is optional.
type binderDI struct {
	fx.In
	AppContext           *bootstrap.ApplicationContext
	Properties           KafkaProperties
	ProducerInterceptors []ProducerMessageInterceptor  `group:"kafka"`
	ConsumerInterceptors []ConsumerDispatchInterceptor `group:"kafka"`
	HandlerInterceptors  []ConsumerHandlerInterceptor  `group:"kafka"`
	TLSCertsManager      certs.Manager                 `optional:"true"`
}
// ProvideKafkaBinder creates the Binder from injected dependencies.
// The BinderOption is replaced as a whole. Interceptors injected via fx are appended to the ones
// already present in the option, after zero-valued (e.g. nil) entries have been dropped.
func ProvideKafkaBinder(di binderDI) Binder {
	configure := func(opt *BinderOption) {
		appConfig := di.AppContext.Config()
		producerInterceptors := append(opt.ProducerInterceptors, filterZeroValues(di.ProducerInterceptors)...)
		consumerInterceptors := append(opt.ConsumerInterceptors, filterZeroValues(di.ConsumerInterceptors)...)
		handlerInterceptors := append(opt.HandlerInterceptors, filterZeroValues(di.HandlerInterceptors)...)
		*opt = BinderOption{
			ApplicationConfig:    appConfig,
			Properties:           di.Properties,
			ProducerInterceptors: producerInterceptors,
			ConsumerInterceptors: consumerInterceptors,
			HandlerInterceptors:  handlerInterceptors,
			TLSCertsManager:      di.TLSCertsManager,
		}
	}
	return NewBinder(di.AppContext, configure)
}
// initDI declares the fx dependencies of initialize. HealthRegistrar is optional.
type initDI struct {
	fx.In
	AppCtx          *bootstrap.ApplicationContext
	Lifecycle       fx.Lifecycle
	Properties      KafkaProperties
	Binder          Binder
	HealthRegistrar health.Registrar `optional:"true"`
}
// initialize ties the Binder to the fx lifecycle and, when a health.Registrar is available,
// registers the Kafka health indicator.
// The Binder is asserted to BinderLifecycle when the hooks run, so a Binder that does not
// implement it panics at start/stop time.
func initialize(di initDI) {
	di.Lifecycle.Append(fx.Hook{
		OnStart: func(context.Context) error {
			//nolint:contextcheck // intentional, given context is cancelled after bootstrap, AppCtx is cancelled when app close
			return di.Binder.(BinderLifecycle).Start(di.AppCtx)
		},
		OnStop: func(ctx context.Context) error {
			return di.Binder.(BinderLifecycle).Shutdown(ctx)
		},
	})

	if di.HealthRegistrar != nil {
		di.HealthRegistrar.MustRegister(NewHealthIndicator(di.Binder))
	}
}
// filterZeroValues returns a new, never-nil slice containing the elements of values that are
// neither invalid (e.g. a nil interface) nor the zero value of their dynamic type.
// The relative order of the kept elements is preserved.
func filterZeroValues[T any](values []T) []T {
	nonZero := make([]T, 0, len(values))
	for _, v := range values {
		if rv := reflect.ValueOf(v); !rv.IsValid() || rv.IsZero() {
			continue
		}
		nonZero = append(nonZero, v)
	}
	return nonZero
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"context"
"errors"
"fmt"
"github.com/IBM/sarama"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"sync"
)
// saramaProducer is a producer bound to a single topic, backed by a sarama.SyncProducer.
// The embedded RWMutex guards syncProducer and closed.
type saramaProducer struct {
	sync.RWMutex
	topic        string
	brokers      []string
	config       *bindingConfig
	keyEncoder   Encoder
	msgLogger    MessageLogger
	interceptors []ProducerMessageInterceptor
	// syncProducer is nil until Start succeeds, and again after Close succeeds
	syncProducer sarama.SyncProducer
	// readyCh is closed once Start succeeds
	readyCh chan struct{}
	closed  bool
}
// newSaramaProducer creates a (not yet started) producer for the given topic and brokers.
// It mutates config: sarama is set up to return successes and errors, as required by the sync producer,
// and to use the random partitioner. Producer interceptors are sorted in place by order.
// The returned error is currently always nil.
func newSaramaProducer(topic string, addrs []string, config *bindingConfig) (*saramaProducer, error) {
	producerCfg := &config.sarama.Producer
	// sync producer must have both Return.Successes and Return.Errors set to true
	producerCfg.Return.Successes = true
	producerCfg.Return.Errors = true
	producerCfg.Partitioner = sarama.NewRandomPartitioner

	order.SortStable(config.producer.interceptors, order.OrderedFirstCompare)
	return &saramaProducer{
		topic:        topic,
		brokers:      addrs,
		config:       config,
		keyEncoder:   config.producer.keyEncoder,
		msgLogger:    config.msgLogger,
		interceptors: config.producer.interceptors,
		readyCh:      make(chan struct{}),
	}, nil
}
// Topic returns the topic this producer publishes to.
func (p *saramaProducer) Topic() (topic string) {
	topic = p.topic
	return topic
}
// Send synchronously publishes message to the producer's topic.
//
// message may be a Message, a *Message, or a raw payload (which is wrapped into a Message with empty headers).
// A nil payload is silently ignored. MessageOptions are applied first, then the ProducerMessageInterceptors in order.
// The interceptors are expected to turn the payload into a sarama.Encoder. If they don't, or if one of them returns
// a nil message context, an error is returned instead of panicking.
// The result of the send is passed through finalizeSend, which applies ProducerMessageFinalizer(s).
func (p *saramaProducer) Send(ctx context.Context, message interface{}, options ...MessageOptions) (err error) {
	p.RLock()
	syncProducer := p.syncProducer
	p.RUnlock()
	if syncProducer == nil {
		return NewKafkaError(ErrorSubTypeCodeIllegalProducerUsage, fmt.Sprintf(`producer for topic "%s" is not started yet`, p.topic))
	}

	msgCtx := p.prepare(ctx, message)
	if msgCtx.Message.Payload == nil {
		return nil
	}

	// apply options
	for _, optionFunc := range options {
		optionFunc(&msgCtx.messageConfig)
	}

	// apply Interceptors
	for _, interceptor := range p.interceptors {
		if msgCtx, err = interceptor.Intercept(msgCtx); err != nil {
			return ErrorSubTypeProducerGeneral.WithMessage("producer interceptor error: %v", err)
		}
		if msgCtx == nil {
			return ErrorSubTypeProducerGeneral.WithMessage("producer interceptor %T returned nil message context", interceptor)
		}
	}

	// encoding is done by interceptors; guard against a missing encoding step instead of panicking
	value, ok := msgCtx.Message.Payload.(sarama.Encoder)
	if !ok {
		return ErrorSubTypeProducerGeneral.WithMessage("message payload is not encoded, expected sarama.Encoder but got %T", msgCtx.Message.Payload)
	}

	// initialize sarama message
	saramaMessage := &sarama.ProducerMessage{
		Topic:    p.topic,
		Headers:  p.convertHeaders(msgCtx.Message.Headers),
		Value:    value,
		Key:      newSaramaEncoder(msgCtx.Key, p.keyEncoder),
		Metadata: msgCtx,
	}
	msgCtx.RawMessage = saramaMessage

	// do send
	switch msgCtx.Mode {
	case modeSync:
		partition, offset, e := syncProducer.SendMessage(saramaMessage)
		// apply finalizers
		return p.finalizeSend(msgCtx, partition, offset, e)
	default:
		return ErrorSubTypeIllegalProducerUsage.WithMessage("%v Mode is not supported", msgCtx.Mode)
	}
}
// ReadyCh returns a channel that is closed once the producer has been started successfully.
func (p *saramaProducer) ReadyCh() <-chan struct{} {
	var ch <-chan struct{} = p.readyCh
	return ch
}
// Start creates the underlying sarama.SyncProducer and closes ReadyCh.
// Starting an already-started producer is a no-op; a closed producer cannot be restarted.
func (p *saramaProducer) Start(_ context.Context) error {
	p.Lock()
	defer p.Unlock()

	if p.closed {
		return ErrorStartClosedBinding.WithMessage("cannot re-start a closed producer [%s]", p.topic)
	}
	if p.syncProducer != nil {
		// already started
		return nil
	}

	sp, e := sarama.NewSyncProducer(p.brokers, &p.config.sarama)
	if e != nil {
		return translateSaramaBindingError(e, "unable to start producer: %v", e)
	}
	p.syncProducer = sp
	close(p.readyCh)
	return nil
}
// Close closes the underlying sarama.SyncProducer and marks this producer as closed, so it cannot be
// started again. Closing a producer that was never started (or is already closed) only marks it closed.
// If closing sarama's producer fails, the error is returned and the producer state is left unchanged.
func (p *saramaProducer) Close() error {
	p.Lock()
	defer p.Unlock()
	if p.syncProducer == nil {
		// nothing to release, but still prevent a later Start (consistent with the subscriber)
		p.closed = true
		return nil
	}
	if e := p.syncProducer.Close(); e != nil {
		// NewKafkaError does not format its text, so format explicitly and keep e as the cause
		return NewKafkaError(ErrorCodeIllegalState, fmt.Sprintf("error when closing producer: %v", e), e)
	}
	p.closed = true
	p.syncProducer = nil
	return nil
}
// Closed reports whether this producer has been closed (see Close).
// It only reads state, so it takes the read lock like Send does.
func (p *saramaProducer) Closed() bool {
	p.RLock()
	defer p.RUnlock()
	return p.closed
}
// prepare wraps v into a MessageContext with default message config.
// v may be a *Message, a Message, or any other value that becomes the payload.
// Headers are always non-nil in the returned context.
func (p *saramaProducer) prepare(ctx context.Context, v interface{}) *MessageContext {
	var msg Message
	switch m := v.(type) {
	case *Message:
		msg = *m
	case Message:
		msg = m
	default:
		msg = Message{Payload: v}
	}
	if msg.Headers == nil {
		msg.Headers = Headers{}
	}
	return &MessageContext{
		Context:       ctx,
		Topic:         p.topic,
		messageConfig: defaultMessageConfig(),
		Source:        p,
		Message:       msg,
	}
}
// finalizeSend logs the sent message and runs every interceptor that is also a ProducerMessageFinalizer,
// in order, each of which may replace the message context and the error.
// A resulting non-nil error that is not already a Kafka error is wrapped into one.
func (p *saramaProducer) finalizeSend(msgCtx *MessageContext, partition int32, offset int64, err error) error {
	p.msgLogger.LogSentMessage(msgCtx.Context, msgCtx.RawMessage)
	for _, interceptor := range p.interceptors {
		if finalizer, ok := interceptor.(ProducerMessageFinalizer); ok {
			msgCtx, err = finalizer.Finalize(msgCtx, partition, offset, err)
		}
	}

	switch {
	case err == nil:
		return nil
	case errors.Is(err, ErrorCategoryKafka):
		return err
	default:
		return NewKafkaError(ErrorSubTypeCodeProducerGeneral, err.Error(), err)
	}
}
// convertHeaders converts message Headers into sarama record headers.
// It returns nil for nil headers. Order of the result follows map iteration and is therefore unspecified.
func (p *saramaProducer) convertHeaders(headers Headers) (ret []sarama.RecordHeader) {
	if headers == nil {
		return nil
	}
	ret = make([]sarama.RecordHeader, 0, len(headers))
	for key, value := range headers {
		ret = append(ret, sarama.RecordHeader{Key: []byte(key), Value: []byte(value)})
	}
	return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"strings"
"time"
)
const (
	// ConfigKafkaPrefix is the configuration prefix bound into KafkaProperties.
	ConfigKafkaPrefix = "kafka"
	// ConfigKafkaBindingPrefix is the configuration prefix of named BindingProperties
	// (presumably "kafka.bindings.<binding-name>", see BindingName).
	ConfigKafkaBindingPrefix = "kafka.bindings"
	// ConfigKafkaDefaultBindingPrefix is the configuration prefix of the "default" BindingProperties.
	ConfigKafkaDefaultBindingPrefix = "kafka.bindings.default"
)
// KafkaProperties holds the Kafka configuration bound from the "kafka" prefix (see BindKafkaProperties).
//
//goland:noinspection GoNameStartsWithPackageName
type KafkaProperties struct {
	// Brokers is the comma-separated list of broker addresses
	Brokers  utils.CommaSeparatedSlice `json:"brokers"`
	Net      Net                       `json:"net"`
	Metadata Metadata                  `json:"metadata"`
	Binder   BinderProperties          `json:"binder"`
	// ClientId is sent to brokers as sarama's ClientID; defaults to the application name
	ClientId string `json:"client-id"`
}

// Net holds network-level security settings.
type Net struct {
	Sasl SASL `json:"sasl"`
	Tls  TLS  `json:"tls"`
}

// Metadata holds cluster metadata settings.
type Metadata struct {
	// RefreshFrequency maps to sarama's Metadata.RefreshFrequency (default 5 minutes)
	RefreshFrequency utils.Duration `json:"refresh-frequency"`
}
// SASL holds SASL authentication settings, copied into sarama's Net.SASL when Enable is true.
type SASL struct {
	// Whether or not to use SASL authentication when connecting to the broker
	// (defaults to false).
	Enable bool `json:"enabled"`
	// Whether or not to send the Kafka SASL handshake first if enabled
	// (defaults to true). You should only set this to false if you're using
	// a non-Kafka SASL proxy.
	Handshake bool `json:"handshake"`
	// User and Password are the credentials for SASL/PLAIN authentication.
	User string `json:"user"`
	// Password tag was previously misspelled as `josn`, which made the tag ineffective.
	Password string `json:"password"`
}
// TLS holds TLS settings for broker connections.
type TLS struct {
	// Enable defaults to false (see BindKafkaProperties)
	Enable bool                   `json:"enabled"`
	Certs  certs.SourceProperties `json:"certs"`
}

// BinderProperties holds settings of the binder itself.
// NOTE(review): semantics of the heartbeat fields are defined by the binder implementation (not visible here);
// defaults set in BindKafkaProperties are init-heartbeat=5s, watchdog-heartbeat=120s,
// heartbeat-curve-factor=0.5, heartbeat-curve-midpoint=10.
type BinderProperties struct {
	InitialHeartbeat       utils.Duration `json:"init-heartbeat"`
	HeartbeatCurveFactor   float64        `json:"heartbeat-curve-factor"`
	HeartbeatCurveMidpoint float64        `json:"heartbeat-curve-midpoint"`
	WatchdogHeartbeat      utils.Duration `json:"watchdog-heartbeat"`
}
// Supported AckMode values (see AckMode.UnmarshalText for parsing rules).
const (
	// AckModeModeAll selects RequireAllAck in WithProducerProperties.
	AckModeModeAll AckMode = "all"
	// AckModeModeLocal selects RequireLocalAck in WithProducerProperties.
	AckModeModeLocal AckMode = "local"
	// AckModeModeNone selects RequireNoAck in WithProducerProperties.
	AckModeModeNone AckMode = "none"
)

// AckMode is the producer acknowledgement mode configured via ProducerProperties.AckMode.
type AckMode string
// UnmarshalText parses an AckMode case-insensitively ("all", "local" or "none").
// It never fails: any unrecognized value falls back to AckModeModeNone.
// NOTE(review): the silent fallback means a typo disables acknowledgements; consider reporting an error.
func (m *AckMode) UnmarshalText(data []byte) error {
	switch mode := AckMode(strings.ToLower(string(data))); mode {
	case AckModeModeAll, AckModeModeLocal:
		*m = mode
	default:
		*m = AckModeModeNone
	}
	return nil
}
// BindingProperties holds the per-binding configuration for producers and consumers.
type BindingProperties struct {
	Producer ProducerProperties `json:"producer"`
	Consumer ConsumerProperties `json:"consumer"`
}

// ProducerProperties is applied via WithProducerProperties. Nil fields leave the configuration unchanged.
type ProducerProperties struct {
	LogLevel *log.LoggingLevel `json:"log-level"`
	AckMode  *AckMode          `json:"ack-mode"`
	// AckTimeout maps to sarama's Producer.Timeout
	AckTimeout *utils.Duration `json:"ack-timeout"`
	// MaxRetry maps to sarama's Producer.Retry.Max
	MaxRetry *int `json:"max-retry"`
	// Backoff maps to sarama's Producer.Retry.Backoff
	Backoff      *utils.Duration        `json:"backoff-interval"`
	Provisioning ProvisioningProperties `json:"provisioning"`
}

// ConsumerProperties is applied via WithConsumerProperties. Nil fields leave the configuration unchanged.
type ConsumerProperties struct {
	LogLevel *log.LoggingLevel `json:"log-level"`
	// Backoff maps to sarama's Consumer.Retry.Backoff
	Backoff *utils.Duration         `json:"backoff-interval"`
	Group   ConsumerGroupProperties `json:"group"`
}

// ProvisioningProperties controls topic provisioning of producers.
type ProvisioningProperties struct {
	// AutoCreateTopic when topic doesn't exist, whether attempt to create one
	AutoCreateTopic *bool `json:"auto-create-topic"`
	// AutoAddPartitions when actual partition counts is less than PartitionCount, whether attempt to add more partitions
	AutoAddPartitions *bool `json:"auto-add-partitions"`
	// AllowLowerPartitions when actual partition counts is less than PartitionCount but AutoAddPartitions is false,
	// whether to accept it instead of returning an error
	AllowLowerPartitions *bool `json:"allow-lower-partitions"`
	// PartitionCount number of partitions of given topic
	PartitionCount *int32 `json:"partition-count"`
	// ReplicationFactor number of replicas per partition when creating topic
	ReplicationFactor *int16 `json:"replication-factor"`
}

// ConsumerGroupProperties maps to sarama's Consumer.Group.Rebalance settings.
type ConsumerGroupProperties struct {
	// JoinTimeout maps to Consumer.Group.Rebalance.Timeout
	JoinTimeout *utils.Duration `json:"join-timeout"`
	// MaxRetry maps to Consumer.Group.Rebalance.Retry.Max
	MaxRetry *int `json:"max-retry"`
	// Backoff maps to Consumer.Group.Rebalance.Retry.Backoff
	Backoff *utils.Duration `json:"backoff-interval"`
}
// BindKafkaProperties binds KafkaProperties from the "kafka" configuration prefix on top of defaults:
// SASL and TLS disabled, SASL handshake on, metadata refreshed every 5 minutes,
// binder heartbeat defaults, and the application name as client ID.
// It panics if binding fails.
func BindKafkaProperties(ctx *bootstrap.ApplicationContext) KafkaProperties {
	props := KafkaProperties{ClientId: ctx.Name()}
	props.Net.Sasl.Enable = false
	props.Net.Sasl.Handshake = true
	props.Net.Tls.Enable = false
	props.Metadata.RefreshFrequency = utils.Duration(5 * time.Minute)
	props.Binder = BinderProperties{
		InitialHeartbeat:       utils.Duration(5 * time.Second),
		WatchdogHeartbeat:      utils.Duration(120 * time.Second),
		HeartbeatCurveFactor:   0.5,
		HeartbeatCurveMidpoint: 10, // recommend > 5
	}

	if err := ctx.Config().Bind(&props, ConfigKafkaPrefix); err != nil {
		panic(errors.Wrap(err, "failed to bind kafka properties"))
	}
	return props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"fmt"
"github.com/IBM/sarama"
)
// globalClientProviderFunc supplies the (global) sarama.Client used for metadata queries.
type globalClientProviderFunc func() (sarama.Client, error)

// clusterAdminProviderFunc supplies the sarama.ClusterAdmin used to create topics and partitions.
type clusterAdminProviderFunc func() (sarama.ClusterAdmin, error)

// saramaTopicProvisioner checks topic existence and creates topics/partitions as configured.
type saramaTopicProvisioner struct {
	globalClient globalClientProviderFunc
	adminClient  clusterAdminProviderFunc
}
// topicExists refreshes cluster metadata and reports whether the given topic is known to the cluster.
func (p *saramaTopicProvisioner) topicExists(topic string) (bool, error) {
	client, err := p.globalClient()
	if err != nil {
		return false, err
	}
	if err = client.RefreshMetadata(); err != nil {
		return false, translateSaramaBindingError(err, "unable to refresh metadata: %v", err)
	}
	topics, err := client.Topics()
	if err != nil {
		return false, translateSaramaBindingError(err, "unable to read topics: %v", err)
	}

	found := false
	for _, name := range topics {
		if name == topic {
			found = true
			break
		}
	}
	return found, nil
}
// provisionTopic makes sure the topic satisfies the producer's provisioning config:
// existing topics get their partitions checked/added, missing topics are created (if allowed).
func (p *saramaTopicProvisioner) provisionTopic(topic string, cfg *bindingConfig) error {
	exists, err := p.topicExists(topic)
	switch {
	case err != nil:
		return err
	case exists:
		return p.tryProvisionPartitions(topic, &cfg.producer.provisioning)
	default:
		return p.tryCreateTopic(topic, &cfg.producer.provisioning)
	}
}
// tryCreateTopic creates the topic with the configured partition count and replication factor.
// It fails if auto-creation is disabled.
func (p *saramaTopicProvisioner) tryCreateTopic(topic string, cfg *topicConfig) error {
	if !cfg.autoCreateTopic {
		return NewKafkaError(ErrorCodeIllegalState, fmt.Sprintf(`kafka topic "%s" doesn't exists, and auto-create is disabled`, topic))
	}

	detail := sarama.TopicDetail{
		NumPartitions:     cfg.partitionCount,
		ReplicationFactor: cfg.replicationFactor,
	}
	admin, err := p.adminClient()
	if err != nil {
		return err
	}
	if err = admin.CreateTopic(topic, &detail, false); err != nil {
		return NewKafkaError(ErrorCodeAutoCreateTopicFailed, fmt.Sprintf(`unable to create topic "%s": %v`, topic, err))
	}
	return nil
}
// tryProvisionPartitions ensures the existing topic has at least the configured number of partitions.
// If it has fewer, it returns nil when lower counts are allowed and auto-add is off,
// fails when auto-add is off otherwise, and adds partitions when auto-add is on.
func (p *saramaTopicProvisioner) tryProvisionPartitions(topic string, cfg *topicConfig) error {
	client, err := p.globalClient()
	if err != nil {
		return err
	}
	parts, err := client.Partitions(topic)
	if err != nil {
		return translateSaramaBindingError(err, "unable to read partitions config of topic %s: %v", topic, err)
	}

	actual := len(parts)
	if actual >= int(cfg.partitionCount) {
		return nil
	}
	if !cfg.autoAddPartitions {
		if cfg.allowLowerPartitions {
			return nil
		}
		return NewKafkaError(ErrorCodeAutoAddPartitionsFailed,
			fmt.Sprintf(`topic "%s" has less partitions than required (expected=%d, actual=%d), but auto-add partitions is disabled`,
				topic, cfg.partitionCount, actual))
	}

	admin, err := p.adminClient()
	if err != nil {
		return err
	}
	if err = admin.CreatePartitions(topic, cfg.partitionCount, nil, true); err != nil {
		return NewKafkaError(ErrorCodeAutoAddPartitionsFailed, fmt.Sprintf(`unable to add partitions to topic "%s": %v`, topic, err))
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"context"
"fmt"
"github.com/IBM/sarama"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"reflect"
"sync"
)
// saramaSubscriber consumes all partitions of a single topic (from the newest offset) using a sarama.Consumer,
// and dispatches each message to the registered handlers.
// The embedded Mutex guards started, closed, consumer, partitions and cancelFunc in Start/Close/Closed.
type saramaSubscriber struct {
	sync.Mutex
	topic       string
	brokers     []string
	config      *bindingConfig
	dispatcher  *saramaDispatcher
	provisioner *saramaTopicProvisioner
	started     bool
	consumer    sarama.Consumer
	partitions  []int32
	// cancelFunc stops the goroutine started by Start
	cancelFunc context.CancelFunc
	closed     bool
}
// newSaramaSubscriber creates a (not yet started) subscriber for the given topic and brokers.
// Dispatch and handler interceptors of config are sorted in place by order before the dispatcher is created.
// The returned error is currently always nil.
func newSaramaSubscriber(topic string, addrs []string, config *bindingConfig, provisioner *saramaTopicProvisioner) (*saramaSubscriber, error) {
	consumerCfg := &config.consumer
	order.SortStable(consumerCfg.dispatchInterceptors, order.OrderedFirstCompare)
	order.SortStable(consumerCfg.handlerInterceptors, order.OrderedFirstCompare)

	sub := &saramaSubscriber{
		topic:       topic,
		brokers:     addrs,
		config:      config,
		provisioner: provisioner,
	}
	sub.dispatcher = newSaramaDispatcher(config)
	return sub, nil
}
// Topic returns the topic this subscriber consumes.
func (s *saramaSubscriber) Topic() (topic string) {
	topic = s.topic
	return topic
}

// Partitions returns the partitions discovered by Start (nil before a successful Start).
func (s *saramaSubscriber) Partitions() (partitions []int32) {
	partitions = s.partitions
	return partitions
}
// Start verifies the topic exists, creates a sarama.Consumer, and consumes every partition of the topic
// from the newest offset in a background goroutine that lives until ctx is cancelled or Close is called.
// Starting an already-started subscriber is a no-op; a closed subscriber cannot be restarted.
//
// On failure, anything created so far (consumer and partition consumers) is closed again, so that a failed
// Start does not leak connections. Errors from checking topic existence are reported as such,
// rather than being masked as a missing topic.
func (s *saramaSubscriber) Start(ctx context.Context) error {
	s.Lock()
	defer s.Unlock()
	switch {
	case s.closed:
		return ErrorStartClosedBinding.WithMessage("cannot re-start a closed subscriber [%s]", s.topic)
	case s.started:
		return nil
	}

	switch ok, e := s.provisioner.topicExists(s.topic); {
	case e != nil:
		return NewKafkaError(ErrorCodeIllegalState, fmt.Sprintf(`unable to verify existence of topic "%s": %v`, s.topic, e), e)
	case !ok:
		return NewKafkaError(ErrorCodeIllegalState, fmt.Sprintf(`topic "%s" does not exists`, s.topic))
	}

	consumer, e := sarama.NewConsumer(s.brokers, &s.config.sarama)
	if e != nil {
		return translateSaramaBindingError(e, "%s", e.Error())
	}

	partitions, e := consumer.Partitions(s.topic)
	if e != nil {
		_ = consumer.Close() // best-effort cleanup, the original error is more relevant
		return translateSaramaBindingError(e, "%s", e.Error())
	}

	partitionConsumers := make([]sarama.PartitionConsumer, 0, len(partitions))
	for _, partition := range partitions {
		pc, ce := consumer.ConsumePartition(s.topic, partition, sarama.OffsetNewest)
		if ce != nil {
			// best-effort cleanup of what was already opened
			for _, opened := range partitionConsumers {
				_ = opened.Close()
			}
			_ = consumer.Close()
			return translateSaramaBindingError(ce, "%s", ce.Error())
		}
		partitionConsumers = append(partitionConsumers, pc)
	}

	s.consumer = consumer
	s.partitions = partitions
	cancelCtx, cancelFunc := context.WithCancel(ctx)
	s.cancelFunc = cancelFunc
	go s.handlePartitions(cancelCtx, partitionConsumers)
	s.started = true
	return nil
}
// Close stops the message-handling goroutine and closes the underlying sarama.Consumer.
// The subscriber is always marked closed (and not started) afterwards, even if closing the consumer fails,
// and the consumer reference is dropped so a repeated Close does not close it twice.
func (s *saramaSubscriber) Close() error {
	s.Lock()
	defer s.Unlock()
	defer func() {
		s.started = false
		s.closed = true
	}()
	if s.cancelFunc != nil {
		s.cancelFunc()
		s.cancelFunc = nil
	}
	if s.consumer == nil {
		return nil
	}
	consumer := s.consumer
	s.consumer = nil
	if e := consumer.Close(); e != nil {
		// NewKafkaError does not format its text, so format explicitly and keep e as the cause
		return NewKafkaError(ErrorCodeIllegalState, fmt.Sprintf("error when closing subscriber: %v", e), e)
	}
	return nil
}
// Closed reports whether Close has been called on this subscriber.
func (s *saramaSubscriber) Closed() (closed bool) {
	s.Lock()
	closed = s.closed
	s.Unlock()
	return closed
}
// AddHandler registers a MessageHandlerFunc with the subscriber's dispatcher, using the binding's
// consumer configuration and the given DispatchOptions.
func (s *saramaSubscriber) AddHandler(handlerFunc MessageHandlerFunc, opts ...DispatchOptions) error {
	consumerCfg := &s.config.consumer
	return s.dispatcher.AddHandler(handlerFunc, consumerCfg, opts)
}
// handlePartitions multiplexes the Messages channels of all partition consumers and dispatches each message
// in its own goroutine. It returns when ctx is done or any of the channels is closed.
// Intended to run in a separate goroutine.
func (s *saramaSubscriber) handlePartitions(ctx context.Context, partitions []sarama.PartitionConsumer) {
	// case 0 is the cancellation signal; case i+1 corresponds to partitions[i]
	cases := make([]reflect.SelectCase, 0, len(partitions)+1)
	cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ctx.Done())})
	for _, pc := range partitions {
		cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(pc.Messages())})
	}

	for {
		chosen, val, ok := reflect.Select(cases)
		if chosen == 0 || !ok {
			// Done channel received or channel closed
			return
		}
		msg, isMsg := val.Interface().(*sarama.ConsumerMessage)
		if !isMsg || msg == nil {
			logger.WithContext(ctx).Warnf("unrecognized object received from subscriber of partition [%d]: %T", chosen-1, val.Interface())
			continue
		}
		go s.handleMessage(utils.MakeMutableContext(ctx), msg) //nolint:contextcheck
	}
}
// handleMessage dispatches a single raw message; dispatch errors are logged as warnings.
// Intended to run in a separate goroutine.
func (s *saramaSubscriber) handleMessage(ctx context.Context, raw *sarama.ConsumerMessage) {
	err := s.dispatcher.Dispatch(ctx, raw, s)
	if err == nil {
		return
	}
	logger.WithContext(ctx).Warnf("failed to handle message: %v", err)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"fmt"
"github.com/IBM/sarama"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"go.uber.org/fx"
"math/rand"
"testing"
"time"
)
// kCtxMockedBroker is the context key under which the current test's *MockBroker is stored.
type kCtxMockedBroker struct{}

// WithMockedBroker returns test.Options that run tests against a sarama.MockBroker listening on a random
// localhost port in [32768, 65534]. The broker is created during test setup. It is exposed through the
// test context (see CurrentMockedBroker) and through fx as *MockBroker. "kafka.brokers" is pointed at it,
// and the broker is closed on teardown.
//
// Mocking uses sarama.MockBroker.SetHandlerByMap. Testers typically mock request/response pairs
// via the Mock... functions, e.g. MockExistingTopic.
//
// Caveat: sarama.MockBroker blocks when a request type is not mocked, which can cascade into failures
// of subsequent tests. A missing request/response mock typically shows up as a test hanging until its
// context expires. This tends to happen after upgrading sarama to a newer version.
//
// Resolution: all requests are visible in sarama.MockBroker.SetHandlerByMap. Find the missing mocks and
// investigate why they are needed. If it's due to an added feature or a change in the sarama package,
// create new mocks accordingly: add them to MockBroker.defaults() and adjust the Mock...() functions
// used by the test cases.
func WithMockedBroker() test.Options {
	//nolint:gosec // Not security related
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	cfg := MockedBrokerConfig{
		Port:   0x7fff + rng.Intn(0x7fff) + 1,
		Topics: []string{"test.topic"},
	}

	var mock *MockBroker
	setup := test.Setup(func(ctx context.Context, t *testing.T) (context.Context, error) {
		mock = NewMockedBroker(t, &cfg)
		return context.WithValue(ctx, kCtxMockedBroker{}, mock), nil
	})
	brokerProps := apptest.WithDynamicProperties(map[string]apptest.PropertyValuerFunc{
		"kafka.brokers": func(ctx context.Context) interface{} {
			return fmt.Sprintf("localhost:%d", cfg.Port)
		},
	})
	fxOpts := apptest.WithFxOptions(fx.Provide(func() *MockBroker {
		return mock
	}))
	teardown := test.Teardown(func(ctx context.Context, t *testing.T) error {
		if current, ok := ctx.Value(kCtxMockedBroker{}).(*MockBroker); ok {
			current.Close()
		}
		return nil
	})
	return test.WithOptions(setup, brokerProps, fxOpts, teardown)
}
// CurrentMockedBroker returns the *MockBroker stored in ctx by WithMockedBroker, or nil when there is none.
func CurrentMockedBroker(ctx context.Context) *MockBroker {
	if broker, ok := ctx.Value(kCtxMockedBroker{}).(*MockBroker); ok {
		return broker
	}
	return nil
}
// MockedBrokerConfig configures NewMockedBroker.
type MockedBrokerConfig struct {
	// Port is the localhost port the mocked broker listens on.
	Port int
	// Topics lists topic names.
	// NOTE(review): not read by NewMockedBroker; confirm whether anything else still uses it.
	Topics []string
}
// NewMockedBroker creates a sarama.MockBroker with broker ID 0 at localhost:cfg.Port, wraps it in a MockBroker
// with no known topics, and installs the default request/response mocks (see MockBroker.Reset).
func NewMockedBroker(t *testing.T, cfg *MockedBrokerConfig) *MockBroker {
	addr := fmt.Sprintf(`localhost:%d`, cfg.Port)
	broker := &MockBroker{
		MockBroker: sarama.NewMockBrokerAddr(t, 0, addr),
		t:          t,
		Topics:     make(map[string]map[int32]struct{}),
	}
	broker.Reset()
	return broker
}
// MockResponseUpdateFunc derives the new mock response for a request type from the currently registered one.
// mr is nil when no response is registered for that request type yet. Used with MockBroker.UpdateMocks.
type MockResponseUpdateFunc func(mr sarama.MockResponse) sarama.MockResponse
// SetOrAppend returns a MockResponseUpdateFunc that registers newMR when nothing is registered yet, and
// otherwise chains newMR after the existing response via sarama.NewMockSequence.
func SetOrAppend(newMR sarama.MockResponse) MockResponseUpdateFunc {
	return func(existing sarama.MockResponse) sarama.MockResponse {
		if existing == nil {
			return newMR
		}
		return sarama.NewMockSequence(existing, newMR)
	}
}
// MockBroker wraps sarama.MockBroker and tracks the installed mocks and known topics, so that individual
// request mocks can be updated incrementally.
type MockBroker struct {
	*sarama.MockBroker
	// t is passed to the sarama mock response constructors.
	t *testing.T
	// Mocks maps request type names (e.g. "MetadataRequest") to the responses installed via SetHandlerByMap.
	Mocks map[string]sarama.MockResponse
	// Topics maps topic name to the set of partitions this broker is mocked as leader of.
	Topics map[string]map[int32]struct{}
}
// Reset reinstalls the default mocks on the underlying broker and forgets all known topics.
func (b *MockBroker) Reset() {
	defaultMocks := b.defaults()
	b.Mocks = defaultMocks
	b.MockBroker.SetHandlerByMap(defaultMocks)
	b.Topics = make(map[string]map[int32]struct{})
}
// UpdateMocks applies each update function to the currently registered response for its request type (nil if
// none is registered). It stores the results, then reinstalls the complete mock map on the underlying broker.
func (b *MockBroker) UpdateMocks(mocks map[string]MockResponseUpdateFunc) {
	for reqType, update := range mocks {
		// a missing entry yields nil, which update functions (e.g. SetOrAppend) treat as "not registered"
		b.Mocks[reqType] = update(b.Mocks[reqType])
	}
	b.MockBroker.SetHandlerByMap(b.Mocks)
}
// AddTopic records partition of topic as known and refreshes the MetadataRequest mock so that this broker is
// reported as leader of every known topic partition.
//
// When appendResp is true, a fresh metadata response is chained after the current one (see SetOrAppend).
// Otherwise the current metadata response is updated, or replaced if it is a response sequence.
func (b *MockBroker) AddTopic(topic string, partition int32, appendResp bool) {
	partitions, ok := b.Topics[topic]
	if !ok {
		partitions = make(map[int32]struct{})
		b.Topics[topic] = partitions
	}
	partitions[partition] = struct{}{}

	if appendResp {
		b.appendMetadataResponse()
		return
	}
	b.updateMetadataResponse()
}
// defaults returns the baseline request/response mocks installed by Reset.
//
// JoinGroupRequest first answers with ErrOffsetsLoadInProgress and then with the range balance strategy.
// SyncGroupRequest answers with ErrOffsetsLoadInProgress; MockGroup chains a successful response after it.
func (b *MockBroker) defaults() map[string]sarama.MockResponse {
	t := b.t
	addr, id := b.MockBroker.Addr(), b.MockBroker.BrokerID()

	metadata := sarama.NewMockMetadataResponse(t).
		SetBroker(addr, id).
		SetController(id)
	joinGroup := sarama.NewMockSequence(
		sarama.NewMockJoinGroupResponse(t).SetError(sarama.ErrOffsetsLoadInProgress),
		sarama.NewMockJoinGroupResponse(t).SetGroupProtocol(sarama.RangeBalanceStrategyName),
	)
	syncGroup := sarama.NewMockSyncGroupResponse(t).SetError(sarama.ErrOffsetsLoadInProgress)

	return map[string]sarama.MockResponse{
		// general
		"HeartbeatRequest":   sarama.NewMockHeartbeatResponse(t),
		"MetadataRequest":    metadata,
		"ApiVersionsRequest": sarama.NewMockApiVersionsResponse(t),
		// pubsub
		"OffsetRequest": sarama.NewMockOffsetResponse(t),
		"FetchRequest":  sarama.NewMockFetchResponse(t, 1),
		// consumer group
		"OffsetFetchRequest":     sarama.NewMockOffsetFetchResponse(t),
		"OffsetCommitRequest":    sarama.NewMockOffsetCommitResponse(t),
		"FindCoordinatorRequest": sarama.NewMockFindCoordinatorResponse(t),
		"JoinGroupRequest":       joinGroup,
		"SyncGroupRequest":       syncGroup,
	}
}
// updateMetadataResponse rewrites the registered MetadataRequest mock so that this broker leads every known
// topic partition. How the current mock is handled depends on its type:
//   - a plain metadata response is updated directly;
//   - a response sequence is replaced by a fresh metadata response;
//   - any other (or missing) response is left unchanged.
func (b *MockBroker) updateMetadataResponse() {
	b.UpdateMocks(map[string]MockResponseUpdateFunc{
		"MetadataRequest": func(current sarama.MockResponse) sarama.MockResponse {
			md, isMetadata := current.(*sarama.MockMetadataResponse)
			if !isMetadata {
				if _, isSequence := current.(*sarama.MockSequence); !isSequence {
					return current
				}
				md = sarama.NewMockMetadataResponse(b.t).
					SetBroker(b.MockBroker.Addr(), b.MockBroker.BrokerID()).
					SetController(b.MockBroker.BrokerID())
			}
			for topic, partitions := range b.Topics {
				for p := range partitions {
					md = md.SetLeader(topic, p, b.BrokerID())
				}
			}
			return md
		},
	})
}
// appendMetadataResponse builds a fresh metadata response in which this broker leads every known topic
// partition. It chains that response after the currently registered MetadataRequest mock (see SetOrAppend).
func (b *MockBroker) appendMetadataResponse() {
	brokerID := b.MockBroker.BrokerID()
	md := sarama.NewMockMetadataResponse(b.t).
		SetBroker(b.MockBroker.Addr(), brokerID).
		SetController(brokerID)
	for topic, partitions := range b.Topics {
		for p := range partitions {
			md = md.SetLeader(topic, p, brokerID)
		}
	}
	b.UpdateMocks(map[string]MockResponseUpdateFunc{"MetadataRequest": SetOrAppend(md)})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"github.com/IBM/sarama"
"github.com/cisco-open/go-lanai/pkg/kafka"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/test"
"go.uber.org/fx"
"sync"
"testing"
)
// kCtxHeadersMocker is the context key under which SubSetupHeadersMocker stores the *MockedHeadersInterceptor.
type kCtxHeadersMocker struct{}
// MockHeadersOut is the fx result of ProvideMockedHeadersInterceptor. The same interceptor is exposed both as
// its concrete type and as a kafka.ConsumerDispatchInterceptor in the "kafka" value group.
type MockHeadersOut struct {
	fx.Out
	Concrete  *MockedHeadersInterceptor
	Interface kafka.ConsumerDispatchInterceptor `group:"kafka"`
}
// ProvideMockedHeadersInterceptor creates an empty MockedHeadersInterceptor. It provides the interceptor to fx
// both as itself and as a member of the "kafka" interceptor group.
func ProvideMockedHeadersInterceptor() MockHeadersOut {
	var out MockHeadersOut
	interceptor := &MockedHeadersInterceptor{Headers: make(map[string]map[int32]map[int64]kafka.Headers)}
	out.Concrete, out.Interface = interceptor, interceptor
	return out
}
// MockHeadersDI injects the interceptor created by ProvideMockedHeadersInterceptor into SubSetupHeadersMocker.
type MockHeadersDI struct {
	fx.In
	HeadersMocker *MockedHeadersInterceptor
}
// SubSetupHeadersMocker returns a setup function that stores di.HeadersMocker in the test context. If a headers
// mocker is already present in the context, it is kept as is.
func SubSetupHeadersMocker(di *MockHeadersDI) test.SetupFunc {
	return func(ctx context.Context, t *testing.T) (context.Context, error) {
		if existing := ctx.Value(kCtxHeadersMocker{}); existing == nil {
			ctx = context.WithValue(ctx, kCtxHeadersMocker{}, di.HeadersMocker)
		}
		return ctx, nil
	}
}
// CurrentHeadersMocker returns the MockedHeadersInterceptor stored in ctx by SubSetupHeadersMocker. When none is
// present, it returns a new empty interceptor so callers can use the result unconditionally. That interceptor is
// not registered anywhere by this function.
func CurrentHeadersMocker(ctx context.Context) *MockedHeadersInterceptor {
	mocker, ok := ctx.Value(kCtxHeadersMocker{}).(*MockedHeadersInterceptor)
	if !ok {
		mocker = &MockedHeadersInterceptor{Headers: make(map[string]map[int32]map[int64]kafka.Headers)}
	}
	return mocker
}
// MockedHeadersInterceptor injects pre-registered headers into consumed messages, keyed by topic, partition and
// offset. It exists because sarama.MockFetchResponse cannot produce record headers (see MockSubscribedMessage).
type MockedHeadersInterceptor struct {
	// mtx guards Headers.
	mtx sync.Mutex
	// Headers maps topic -> partition -> offset -> headers to inject.
	Headers map[string]map[int32]map[int64]kafka.Headers
}
// Order returns order.Highest, presumably so that the mocked headers are in place before other dispatch
// interceptors look at the message.
func (*MockedHeadersInterceptor) Order() int {
	return order.Highest
}
// Intercept adds the headers registered via MockHeaders for the message's topic/partition/offset. They are added
// both to the kafka.Message headers and to the raw *sarama.ConsumerMessage. Other raw message types, and messages
// without registered headers, pass through unchanged. It never returns an error.
func (i *MockedHeadersInterceptor) Intercept(msgCtx *kafka.MessageContext) (*kafka.MessageContext, error) {
	i.mtx.Lock()
	defer i.mtx.Unlock()

	raw, isConsumerMsg := msgCtx.RawMessage.(*sarama.ConsumerMessage)
	if !isConsumerMsg {
		return msgCtx, nil
	}
	// indexing a missing (nil) inner map yields found == false, so one lookup covers all three levels
	mocked, found := i.Headers[raw.Topic][raw.Partition][raw.Offset]
	if !found {
		return msgCtx, nil
	}

	if msgCtx.Message.Headers == nil {
		msgCtx.Message.Headers = make(kafka.Headers)
	}
	for key, value := range mocked {
		msgCtx.Message.Headers[key] = value
		raw.Headers = append(raw.Headers, &sarama.RecordHeader{Key: []byte(key), Value: []byte(value)})
	}
	return msgCtx, nil
}
// MockHeaders registers headers that Intercept injects into the consumed message at the given topic, partition
// and offset. It replaces any headers previously registered for that position.
//
// Missing (or nil) nested maps are created on demand, so the zero value of MockedHeadersInterceptor is usable.
func (i *MockedHeadersInterceptor) MockHeaders(topic string, partition int32, offset int64, headers kafka.Headers) {
	i.mtx.Lock()
	defer i.mtx.Unlock()

	if i.Headers == nil {
		i.Headers = make(map[string]map[int32]map[int64]kafka.Headers)
	}
	partitions := i.Headers[topic]
	if partitions == nil {
		partitions = make(map[int32]map[int64]kafka.Headers)
		i.Headers[topic] = partitions
	}
	offsets := partitions[partition]
	if offsets == nil {
		offsets = make(map[int64]kafka.Headers)
		partitions[partition] = offsets
	}
	offsets[offset] = headers
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"github.com/IBM/sarama"
)
// MockCreateTopic prepares the broker in ctx for the creation of topic:
//   - chains default DescribeAcls and CreateTopics responses after any existing ones;
//   - registers partition 0 of topic by appending a new metadata response.
func MockCreateTopic(ctx context.Context, topic string) {
	broker := CurrentMockedBroker(ctx)
	// TODO how to mock error response?
	broker.UpdateMocks(map[string]MockResponseUpdateFunc{
		"DescribeAclsRequest": SetOrAppend(sarama.NewMockListAclsResponse(broker.t)),
		"CreateTopicsRequest": SetOrAppend(sarama.NewMockCreateTopicsResponse(broker.t)),
	})
	broker.AddTopic(topic, 0, true)
}
// MockExistingTopic makes topic/partition appear to exist on the broker in ctx. Its oldest and newest offsets are
// both reported as 0, and the current metadata response is updated so the broker leads that partition.
func MockExistingTopic(ctx context.Context, topic string, partition int32) {
	broker := CurrentMockedBroker(ctx)
	setOffsets := func(mr sarama.MockResponse) sarama.MockResponse {
		offsets := mr.(*sarama.MockOffsetResponse).SetOffset(topic, partition, sarama.OffsetOldest, 0)
		return offsets.SetOffset(topic, partition, sarama.OffsetNewest, 0)
	}
	broker.UpdateMocks(map[string]MockResponseUpdateFunc{"OffsetRequest": setOffsets})
	broker.AddTopic(topic, partition, false)
}
// MockCreatePartition prepares the broker in ctx for the creation of a partition of topic:
//   - chains default DescribeAcls and CreatePartitions responses after any existing ones;
//   - registers the partition by appending a new metadata response.
func MockCreatePartition(ctx context.Context, topic string, partition int32) {
	broker := CurrentMockedBroker(ctx)
	// TODO how to mock error response?
	broker.UpdateMocks(map[string]MockResponseUpdateFunc{
		"DescribeAclsRequest":     SetOrAppend(sarama.NewMockListAclsResponse(broker.t)),
		"CreatePartitionsRequest": SetOrAppend(sarama.NewMockCreatePartitionsResponse(broker.t)),
	})
	broker.AddTopic(topic, partition, true)
}
// MockProduce chains a produce response for partition 0 of topic after any existing ProduceRequest mock.
// The response carries sarama.ErrUnknown when fail is true, and sarama.ErrNoError otherwise.
func MockProduce(ctx context.Context, topic string, fail bool) {
	broker := CurrentMockedBroker(ctx)
	errCode := sarama.ErrNoError
	if fail {
		errCode = sarama.ErrUnknown
	}
	resp := sarama.NewMockProduceResponse(broker.t).SetError(topic, 0, errCode)
	broker.UpdateMocks(map[string]MockResponseUpdateFunc{
		"ProduceRequest": SetOrAppend(resp),
	})
}
// MockGroup sets up consumer-group mocks on the broker in ctx:
//   - the broker acts as coordinator for group;
//   - a SyncGroup response assigning partition of topic to the member is chained after the existing one;
//   - offset commits for that partition report no error.
func MockGroup(ctx context.Context, topic string, group string, partition int32) {
	broker := CurrentMockedBroker(ctx)
	assignment := &sarama.ConsumerGroupMemberAssignment{
		Version: 3,
		Topics:  map[string][]int32{topic: {partition}},
	}
	setCoordinator := func(mr sarama.MockResponse) sarama.MockResponse {
		return mr.(*sarama.MockFindCoordinatorResponse).
			SetCoordinator(sarama.CoordinatorGroup, group, broker.MockBroker)
	}
	acceptCommits := func(mr sarama.MockResponse) sarama.MockResponse {
		return mr.(*sarama.MockOffsetCommitResponse).SetError(group, topic, partition, sarama.ErrNoError)
	}
	broker.UpdateMocks(map[string]MockResponseUpdateFunc{
		"FindCoordinatorRequest": setCoordinator,
		"SyncGroupRequest":       SetOrAppend(sarama.NewMockSyncGroupResponse(broker.t).SetMemberAssignment(assignment)),
		"OffsetCommitRequest":    acceptCommits,
	})
}
// MockSubscribedMessage makes the broker in ctx serve msg at topic/partition/offset to subscribers. Offsets range
// from 0 to offset, and the fetch response contains the message, with its key only when the key is non-empty.
// Non-empty headers are registered with the headers mocker (see CurrentHeadersMocker).
func MockSubscribedMessage(ctx context.Context, topic string, partition int32, offset int64, msg MockedMessage) {
	broker := CurrentMockedBroker(ctx)
	value := sarama.ByteEncoder(msg.Value)
	broker.UpdateMocks(map[string]MockResponseUpdateFunc{
		"OffsetRequest": func(mr sarama.MockResponse) sarama.MockResponse {
			offsets := mr.(*sarama.MockOffsetResponse).SetOffset(topic, partition, sarama.OffsetOldest, 0)
			return offsets.SetOffset(topic, partition, sarama.OffsetNewest, offset)
		},
		"FetchRequest": func(mr sarama.MockResponse) sarama.MockResponse {
			fetch := mr.(*sarama.MockFetchResponse)
			if len(msg.Key) != 0 {
				return fetch.SetMessageWithKey(topic, partition, offset, sarama.ByteEncoder(msg.Key), value)
			}
			return fetch.SetMessage(topic, partition, offset, value)
		},
	})
	if len(msg.Headers) == 0 {
		return
	}
	// Headers are mocked separately. As of sarama 1.38.x, sarama.MockFetchResponse can only add legacy messages
	// without headers: sarama.MockFetchResponse.For() uses sarama.FetchResponse.AddMessage, not AddRecord.
	CurrentHeadersMocker(ctx).MockHeaders(topic, partition, offset, msg.Headers)
}
// MockGroupMessage makes the broker in ctx serve msg at topic/partition/offset to consumer group members:
//   - like MockSubscribedMessage, it mocks offsets 0..offset and the fetch response;
//   - it additionally reports offset as the committed offset of group for that partition;
//   - non-empty headers are registered with the headers mocker.
func MockGroupMessage(ctx context.Context, topic, group string, partition int32, offset int64, msg MockedMessage) {
	broker := CurrentMockedBroker(ctx)
	value := sarama.ByteEncoder(msg.Value)
	broker.UpdateMocks(map[string]MockResponseUpdateFunc{
		"OffsetRequest": func(mr sarama.MockResponse) sarama.MockResponse {
			offsets := mr.(*sarama.MockOffsetResponse).SetOffset(topic, partition, sarama.OffsetOldest, 0)
			return offsets.SetOffset(topic, partition, sarama.OffsetNewest, offset)
		},
		"FetchRequest": func(mr sarama.MockResponse) sarama.MockResponse {
			fetch := mr.(*sarama.MockFetchResponse)
			if len(msg.Key) != 0 {
				return fetch.SetMessageWithKey(topic, partition, offset, sarama.ByteEncoder(msg.Key), value)
			}
			return fetch.SetMessage(topic, partition, offset, value)
		},
		"OffsetFetchRequest": func(mr sarama.MockResponse) sarama.MockResponse {
			committed := mr.(*sarama.MockOffsetFetchResponse)
			return committed.SetOffset(group, topic, partition, offset, "", sarama.ErrNoError)
		},
	})
	if len(msg.Headers) == 0 {
		return
	}
	// Headers are mocked separately. As of sarama 1.38.x, sarama.MockFetchResponse can only add legacy messages
	// without headers: sarama.MockFetchResponse.For() uses sarama.FetchResponse.AddMessage, not AddRecord.
	CurrentHeadersMocker(ctx).MockHeaders(topic, partition, offset, msg.Headers)
}
// MockedMessage describes a message served by the mocked broker via MockSubscribedMessage or MockGroupMessage.
type MockedMessage struct {
	// Key is optional; when empty, the message is mocked without a key.
	Key   []byte
	Value []byte
	// Headers, when non-empty, are injected through the headers mocker (see CurrentHeadersMocker).
	Headers map[string]string
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafka
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"go.uber.org/fx"
)
// tracingOpName prefixes the operation name of every span created by the kafka tracing interceptors,
// e.g. "kafka send".
const tracingOpName = "kafka"
// tracingDI holds the dependencies of tracingProvider.
type tracingDI struct {
	fx.In
	// Tracer is optional; the tracing interceptors are only created when it is available.
	Tracer opentracing.Tracer `optional:"true"`
}
// tracingProvider registers the producer, consumer-dispatch and handler tracing interceptors in FxGroup.
// When no opentracing.Tracer is available, all three results are nil.
func tracingProvider() fx.Annotated {
	build := func(di tracingDI) (ProducerMessageInterceptor, ConsumerDispatchInterceptor, ConsumerHandlerInterceptor) {
		if di.Tracer == nil {
			return nil, nil, nil
		}
		return newKafkaInterceptors(di.Tracer)
	}
	return fx.Annotated{
		Group:  FxGroup,
		Target: build,
	}
}
// newKafkaInterceptors creates the producer, consumer-dispatch and handler tracing interceptors, all sharing tracer.
func newKafkaInterceptors(tracer opentracing.Tracer) (ProducerMessageInterceptor, ConsumerDispatchInterceptor, ConsumerHandlerInterceptor) {
	producer := &kafkaProducerInterceptor{tracer: tracer}
	consumer := &kafkaConsumerInterceptor{tracer: tracer}
	handler := &kafkaHandlerInterceptor{tracer: tracer}
	return producer, consumer, handler
}
// kafkaProducerInterceptor implements kafka.ProducerMessageInterceptor and kafka.ProducerMessageFinalizer.
// Intercept creates a client span per outgoing message and injects its context into the message headers;
// Finalize tags that span and finishes it via FinishAndRewind.
type kafkaProducerInterceptor struct {
	tracer opentracing.Tracer
}
// Intercept creates a "kafka send" client span following the span in msgCtx.Context (FollowsOrNoSpan;
// presumably no span is created when the context carries none). The span is tagged with topic, command and,
// when present, the message key. The span context is injected into the message headers.
func (i kafkaProducerInterceptor) Intercept(msgCtx *MessageContext) (*MessageContext, error) {
	const cmd = "send"
	spanOpts := []tracing.SpanOption{
		tracing.SpanKind(ext.SpanKindRPCClientEnum),
		tracing.SpanTag("topic", msgCtx.Topic),
		tracing.SpanTag("cmd", cmd),
		i.spanPropagation(msgCtx),
	}
	if key := msgCtx.Key; key != nil {
		spanOpts = append(spanOpts, tracing.SpanTag("key", fmt.Sprint(key)))
	}
	msgCtx.Context = tracing.WithTracer(i.tracer).
		WithOpName(tracingOpName + " " + cmd).
		WithOptions(spanOpts...).
		FollowsOrNoSpan(msgCtx.Context)
	return msgCtx, nil
}
// Finalize tags the current span with err, or with the partition and offset on success, and finishes it via
// FinishAndRewind. err is returned unchanged.
func (i kafkaProducerInterceptor) Finalize(msgCtx *MessageContext, p int32, offset int64, err error) (*MessageContext, error) {
	finisher := tracing.WithTracer(i.tracer)
	if err == nil {
		finisher = finisher.WithOptions(tracing.SpanTag("partition", p))
		finisher = finisher.WithOptions(tracing.SpanTag("offset", offset))
	} else {
		finisher = finisher.WithOptions(tracing.SpanTag("err", err))
	}
	msgCtx.Context = finisher.FinishAndRewind(msgCtx.Context)
	return msgCtx, err
}
// spanPropagation returns a SpanOption that injects the span context into the outgoing message headers, using
// the opentracing.TextMap format.
//
// opentracing.TextMapCarrier writes directly into the map, so a nil headers map is created first; writing to a
// nil map would panic.
func (i kafkaProducerInterceptor) spanPropagation(msgCtx *MessageContext) tracing.SpanOption {
	return func(span opentracing.Span) {
		if msgCtx.Message.Headers == nil {
			msgCtx.Message.Headers = make(Headers)
		}
		// we ignore error, since we can't do anything about it
		_ = i.tracer.Inject(span.Context(), opentracing.TextMap, opentracing.TextMapCarrier(msgCtx.Message.Headers))
	}
}
// kafkaConsumerInterceptor implements kafka.ConsumerDispatchInterceptor and kafka.ConsumerDispatchFinalizer.
// Intercept creates a server span per dispatched message, continuing any trace propagated in the message headers;
// Finalize finishes that span via FinishAndRewind.
type kafkaConsumerInterceptor struct {
	tracer opentracing.Tracer
}
// Intercept traces a dispatched message in two steps:
//  1. It forces a new span, using the span context extracted from the message headers (if any) as RPC-server
//     reference.
//  2. It creates a server span named after the message source: "kafka subscribe" for a Subscriber, "kafka consume"
//     for a GroupConsumer, "kafka recv" otherwise.
//
// The second span is tagged with topic, command and, when present, the message key.
func (i kafkaConsumerInterceptor) Intercept(msgCtx *MessageContext) (*MessageContext, error) {
	// continue the producer's trace, if the headers carry one
	ctx := tracing.WithTracer(i.tracer).
		WithStartOptions(i.spanPropagation(msgCtx)).
		ForceNewSpan(msgCtx.Context)

	var cmd string
	switch msgCtx.Source.(type) {
	case Subscriber:
		cmd = "subscribe"
	case GroupConsumer:
		cmd = "consume"
	default:
		cmd = "recv"
	}

	spanOpts := []tracing.SpanOption{
		tracing.SpanKind(ext.SpanKindRPCServerEnum),
		tracing.SpanTag("topic", msgCtx.Topic),
		tracing.SpanTag("cmd", cmd),
	}
	if key := msgCtx.Key; key != nil {
		spanOpts = append(spanOpts, tracing.SpanTag("key", fmt.Sprint(key)))
	}
	msgCtx.Context = tracing.WithTracer(i.tracer).
		WithOpName(tracingOpName + " " + cmd).
		WithOptions(spanOpts...).
		NewSpanOrFollows(ctx)
	return msgCtx, nil
}
// Finalize tags the current span with err (if any) and finishes it via FinishAndRewind. err is returned unchanged.
func (i kafkaConsumerInterceptor) Finalize(msgCtx *MessageContext, err error) (*MessageContext, error) {
	finisher := tracing.WithTracer(i.tracer)
	if err == nil {
		msgCtx.Context = finisher.FinishAndRewind(msgCtx.Context)
		return msgCtx, nil
	}
	msgCtx.Context = finisher.WithOptions(tracing.SpanTag("err", err)).FinishAndRewind(msgCtx.Context)
	return msgCtx, err
}
// spanPropagation extracts a span context from the message headers (opentracing.TextMap format) and returns it
// as an RPC-server start option. When extraction fails (e.g. the headers carry no span context), it returns a
// no-op option.
func (i kafkaConsumerInterceptor) spanPropagation(msgCtx *MessageContext) opentracing.StartSpanOption {
	carrier := opentracing.TextMapCarrier(msgCtx.Message.Headers)
	if spanCtx, err := i.tracer.Extract(opentracing.TextMap, carrier); err == nil {
		return ext.RPCServerOption(spanCtx)
	}
	return noopStartSpanOption{}
}
// kafkaHandlerInterceptor implements kafka.ConsumerHandlerInterceptor.
// BeforeHandling creates a "kafka handle" span; AfterHandling finishes it via FinishAndRewind.
type kafkaHandlerInterceptor struct {
	tracer opentracing.Tracer
}
// BeforeHandling creates a "kafka handle" server span as a descendant of the span in ctx (DescendantOrNoSpan).
// It never returns an error.
func (i kafkaHandlerInterceptor) BeforeHandling(ctx context.Context, _ *Message) (context.Context, error) {
	const cmd = "handle"
	op := tracing.WithTracer(i.tracer).
		WithOpName(tracingOpName + " " + cmd).
		WithOptions(
			tracing.SpanKind(ext.SpanKindRPCServerEnum),
			tracing.SpanTag("cmd", cmd),
		)
	return op.DescendantOrNoSpan(ctx), nil
}
// AfterHandling tags the handling span with err (if any) and finishes it via FinishAndRewind. err is returned
// unchanged.
func (i kafkaHandlerInterceptor) AfterHandling(ctx context.Context, _ *Message, err error) (context.Context, error) {
	finisher := tracing.WithTracer(i.tracer)
	if err == nil {
		return finisher.FinishAndRewind(ctx), nil
	}
	return finisher.WithOptions(tracing.SpanTag("err", err)).FinishAndRewind(ctx), err
}
// noopStartSpanOption is an opentracing.StartSpanOption that leaves the start options untouched.
type noopStartSpanOption struct{}

// compile-time check that noopStartSpanOption satisfies opentracing.StartSpanOption
var _ opentracing.StartSpanOption = noopStartSpanOption{}

// Apply implements opentracing.StartSpanOption and does nothing.
func (noopStartSpanOption) Apply(*opentracing.StartSpanOptions) {}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package log
import (
"io"
"strings"
)
// writerAdapter implements io.Writer by wrapping our Logger interface: each Write becomes one log entry
// (see Write).
type writerAdapter struct {
	logger Logger
}
// NewWriterAdapter returns an io.Writer that prints each non-empty write through logger at level lvl, trimmed
// of surrounding whitespace.
//
// NOTE(review): RuntimeCaller(5) presumably skips the adapter's own frames so that the reported caller is the
// code writing to the io.Writer. Confirm the frame count if the call depth changes.
func NewWriterAdapter(logger Logger, lvl LoggingLevel) io.Writer {
	leveled := logger.WithCaller(RuntimeCaller(5)).WithLevel(lvl)
	return &writerAdapter{logger: leveled}
}
// Write logs p, trimmed of leading and trailing whitespace, as a single entry and reports all of p as written.
// Empty writes are ignored. It never returns an error.
func (w writerAdapter) Write(p []byte) (n int, err error) {
	if len(p) != 0 {
		w.logger.Print(strings.TrimSpace(string(p)))
		n = len(p)
	}
	return n, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package log
import (
"context"
"github.com/cisco-open/go-lanai/pkg/log/internal"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// configurableZapLogger implements Logger and Contextual.
// It is a zapLogger whose minimum level can be changed at runtime (setMinLevel) and which adds configured
// context values to its entries via WithContext.
type configurableZapLogger struct {
	zapLogger
	// leveler is the atomic level given to this logger's core; setMinLevel updates it.
	leveler zap.AtomicLevel
	// lvl is the LoggingLevel most recently applied via setMinLevel.
	lvl LoggingLevel
	// valuers compute log fields from a context.Context, keyed by log field name.
	valuers ContextValuers
}
// newConfigurableZapLogger creates a logger named name that writes to core. Its minimum level is initialized to
// logLevel through leveler, and WithContext uses valuers.
func newConfigurableZapLogger(name string, core zapcore.Core, logLevel LoggingLevel, leveler zap.AtomicLevel, valuers ContextValuers) *configurableZapLogger {
	base := zapLogger{
		core:        core,
		name:        name,
		clock:       zapcore.DefaultClock,
		stacktracer: RuntimeCaller(4),
	}
	l := &configurableZapLogger{zapLogger: base, leveler: leveler, valuers: valuers}
	l.setMinLevel(logLevel)
	return l
}
// WithContext returns a logger that adds one key/value pair per configured context valuer, evaluated against ctx.
// A nil ctx returns the receiver itself.
func (l *configurableZapLogger) WithContext(ctx context.Context) Logger {
	if ctx == nil {
		return l
	}
	kvs := make([]interface{}, 0, 2*len(l.valuers))
	for logKey, valuer := range l.valuers {
		kvs = append(kvs, logKey, valuer(ctx))
	}
	return l.withKV(kvs)
}
// setMinLevel applies the zap level corresponding to lv to the atomic leveler and remembers lv in l.lvl.
// LevelOff maps to zapcore.InvalidLevel (presumably so that no real level is enabled); unknown values fall back
// to zapcore.InfoLevel.
func (l *configurableZapLogger) setMinLevel(lv LoggingLevel) {
	zapLevel := zapcore.InfoLevel
	switch lv {
	case LevelOff:
		zapLevel = zapcore.InvalidLevel
	case LevelDebug:
		zapLevel = zapcore.DebugLevel
	case LevelWarn:
		zapLevel = zapcore.WarnLevel
	case LevelError:
		zapLevel = zapcore.ErrorLevel
	}
	l.leveler.SetLevel(zapLevel)
	l.lvl = lv
}
// IsTerminal implements internal.TerminalAware.
// It returns true only when the underlying core itself implements internal.TerminalAware and reports a terminal.
func (l *configurableZapLogger) IsTerminal() bool {
	if aware, ok := l.core.(internal.TerminalAware); ok {
		return aware.IsTerminal()
	}
	return false
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package log
import "strings"
/*********************
	LoggingLevel
 *********************/

// LoggingLevel enumerates the supported logging levels, from LevelOff to LevelError.
type LoggingLevel int

const (
	LevelOff LoggingLevel = iota
	LevelDebug
	LevelInfo
	LevelWarn
	LevelError
)

// Textual names of the logging levels, as produced by MarshalText. UnmarshalText accepts them in any case.
const (
	LevelOffText   = "OFF"
	LevelDebugText = "DEBUG"
	LevelInfoText  = "INFO"
	LevelWarnText  = "WARN"
	LevelErrorText = "ERROR"
)

var (
	// loggingLevelItoA maps every known level to its textual name.
	loggingLevelItoA = map[LoggingLevel]string{
		LevelOff:   LevelOffText,
		LevelDebug: LevelDebugText,
		LevelInfo:  LevelInfoText,
		LevelWarn:  LevelWarnText,
		LevelError: LevelErrorText,
	}
	// loggingLevelAtoI is the inverse of loggingLevelItoA, keyed by upper-cased name.
	loggingLevelAtoI = func() map[string]LoggingLevel {
		byText := make(map[string]LoggingLevel, len(loggingLevelItoA))
		for lvl, text := range loggingLevelItoA {
			byText[strings.ToUpper(text)] = lvl
		}
		return byText
	}()
)

// String implements fmt.Stringer. Unknown values are rendered as LevelErrorText.
func (l LoggingLevel) String() string {
	text, known := loggingLevelItoA[l]
	if !known {
		return LevelErrorText
	}
	return text
}

// MarshalText implements encoding.TextMarshaler using String.
func (l LoggingLevel) MarshalText() ([]byte, error) {
	text := l.String()
	return []byte(text), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. Matching is case-insensitive. Unrecognized text leaves the
// level unchanged and is not reported as an error.
func (l *LoggingLevel) UnmarshalText(data []byte) error {
	parsed, known := loggingLevelAtoI[strings.ToUpper(string(data))]
	if known {
		*l = parsed
	}
	return nil
}
/*********************
	Format
 *********************/

// Format selects how log entries are encoded.
type Format int

const (
	_ Format = iota
	FormatText
	FormatJson
)

// Textual names of the formats, as produced by MarshalText. UnmarshalText accepts them in any case.
const (
	FormatJsonText = "json"
	FormatTextText = "text"
)

var (
	// formatItoA maps every known format to its textual name.
	formatItoA = map[Format]string{
		FormatJson: FormatJsonText,
		FormatText: FormatTextText,
	}
	// formatAtoI is the inverse of formatItoA.
	formatAtoI = func() map[string]Format {
		byText := make(map[string]Format, len(formatItoA))
		for f, text := range formatItoA {
			byText[text] = f
		}
		return byText
	}()
)

// String implements fmt.Stringer. Unknown values are rendered as "unknown".
func (f Format) String() string {
	text, known := formatItoA[f]
	if !known {
		return "unknown"
	}
	return text
}

// MarshalText implements encoding.TextMarshaler using String.
func (f Format) MarshalText() ([]byte, error) {
	text := f.String()
	return []byte(text), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. Matching is case-insensitive. Unrecognized text leaves the
// format unchanged and is not reported as an error.
func (f *Format) UnmarshalText(data []byte) error {
	parsed, known := formatAtoI[strings.ToLower(string(data))]
	if known {
		*f = parsed
	}
	return nil
}
/*********************
	LoggerType
 *********************/

// LoggerType selects the destination of a logger's output.
type LoggerType int

const (
	_ LoggerType = iota
	TypeConsole
	TypeFile
	TypeHttp
	TypeMQ
)

// Textual names of the logger types, as produced by MarshalText. UnmarshalText accepts them in any case.
const (
	TypeConsoleText = "console"
	TypeFileText    = "file"
	TypeHttpText    = "http"
	TypeMQText      = "mq"
)

var (
	// typeItoA maps every known logger type to its textual name.
	typeItoA = map[LoggerType]string{
		TypeConsole: TypeConsoleText,
		TypeFile:    TypeFileText,
		TypeHttp:    TypeHttpText,
		TypeMQ:      TypeMQText,
	}
	// typeAtoI is the inverse of typeItoA.
	typeAtoI = func() map[string]LoggerType {
		byText := make(map[string]LoggerType, len(typeItoA))
		for typ, text := range typeItoA {
			byText[text] = typ
		}
		return byText
	}()
)

// String implements fmt.Stringer. Unknown values are rendered as "unknown".
func (t LoggerType) String() string {
	text, known := typeItoA[t]
	if !known {
		return "unknown"
	}
	return text
}

// MarshalText implements encoding.TextMarshaler using String.
func (t LoggerType) MarshalText() ([]byte, error) {
	text := t.String()
	return []byte(text), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. Matching is case-insensitive. Unrecognized text leaves the
// type unchanged and is not reported as an error.
func (t *LoggerType) UnmarshalText(data []byte) error {
	parsed, known := typeAtoI[strings.ToLower(string(data))]
	if known {
		*t = parsed
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package log
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"os"
"path/filepath"
)
/*
Common functions that are useful to any logger factory
*/
const (
	// keyLevelDefault is the key in Properties.Levels that holds the root (default) logging level.
	keyLevelDefault = "default"
	// keySeparator separates segments of logger keys; a level set on a key also applies to keys nested under it.
	keySeparator = "."
	// nameLevelDefault is another logger name that setLevel treats as the root level.
	nameLevelDefault = "ROOT"
)
// loggerKey normalizes a logger name into the key used for level lookup and caching, by converting camel case
// to snake case.
func loggerKey(name string) string {
	key := utils.CamelToSnakeCase(name)
	return key
}
// convertLevelsNameToKey re-keys a level configuration from logger names to normalized logger keys (see
// loggerKey). The result is never nil.
func convertLevelsNameToKey(byNames map[string]LoggingLevel) (byKeys map[string]LoggingLevel) {
	byKeys = make(map[string]LoggingLevel, len(byNames))
	for name, lvl := range byNames {
		byKeys[loggerKey(name)] = lvl
	}
	return byKeys
}
// openOrCreateFile opens location for appending (write-only). Missing parent directories and the file itself are
// created as needed. An empty location is rejected with an error.
//
// Parent directories are created with mode 0755 (before umask). A directory needs the execute (search) bit for
// its entries to be reachable; read permission without it would only expose the directory listing, not the log
// file. The file is created with mode 0666 (before umask).
func openOrCreateFile(location string) (*os.File, error) {
	if location == "" {
		return nil, fmt.Errorf("location is missing for file logger")
	}
	if e := os.MkdirAll(filepath.Dir(location), 0755); e != nil {
		return nil, e
	}
	return os.OpenFile(location, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666)
}
// buildContextValuerFromConfig creates one context valuer per entry of properties.Mappings. Each mapping is
// context-key -> log-key: the valuer is registered under the log key and returns ctx.Value(context-key).
func buildContextValuerFromConfig(properties *Properties) ContextValuers {
	valuers := ContextValuers{}
	for ctxKey, logKey := range properties.Mappings {
		// Per-iteration copy: before Go 1.22 every closure created in this loop would share one loop variable,
		// so all valuers would read the last context key.
		ctxKey := ctxKey
		valuers[logKey] = func(ctx context.Context) interface{} {
			return ctx.Value(ctxKey)
		}
	}
	return valuers
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package log
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log/internal"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"os"
"strings"
"time"
)
// zapEncoderConfig is the encoder configuration used by both the text (formatted) and JSON encoders.
// Timestamps are rendered in UTC as RFC 3339 with up to millisecond precision; the ".999" layout omits trailing
// zeros of the fraction.
var zapEncoderConfig = zapcore.EncoderConfig{
	TimeKey:       LogKeyTimestamp,
	LevelKey:      LogKeyLevel,
	NameKey:       LogKeyName,
	CallerKey:     LogKeyCaller,
	FunctionKey:   zapcore.OmitKey,
	MessageKey:    LogKeyMessage,
	StacktraceKey: LogKeyStacktrace,
	LineEnding:    zapcore.DefaultLineEnding,
	EncodeLevel:   zapcore.LowercaseLevelEncoder,
	// the parameter is named ts so it does not shadow the time package
	EncodeTime: func(ts time.Time, encoder zapcore.PrimitiveArrayEncoder) {
		// RFC3339 with milliseconds
		encoder.AppendString(ts.UTC().Format(`2006-01-02T15:04:05.999Z07:00`))
	},
	EncodeDuration: zapcore.SecondsDurationEncoder,
	EncodeCaller:   zapcore.ShortCallerEncoder,
}
// zapCoreCreator builds a zapcore.Core for the configured outputs whose enabled levels are governed by level.
type zapCoreCreator func(level zap.AtomicLevel) zapcore.Core
// zapLoggerFactory creates and caches zap-based loggers. It applies level, output and context-valuer
// configuration to every logger it has created.
type zapLoggerFactory struct {
	// rootLogLevel applies when neither a logger's key nor any of its parent prefixes has a configured level.
	rootLogLevel LoggingLevel
	// logLevels maps normalized logger keys (or key prefixes) to configured levels.
	logLevels map[string]LoggingLevel
	// coreCreator builds each logger's zapcore.Core from the configured outputs.
	coreCreator zapCoreCreator
	// properties is the configuration passed at construction (refresh does not replace it).
	properties *Properties
	// effectiveValuers are the configured mappings plus extraValuers; the same map is handed to created loggers.
	effectiveValuers ContextValuers
	// extraValuers are valuers added via addContextValuers; refresh merges them back in.
	extraValuers ContextValuers
	// registry caches created loggers by normalized key.
	registry map[string]*configurableZapLogger
}
// newZapLoggerFactory creates a factory from properties. The root level is taken from the keyLevelDefault entry
// of properties.Levels, or LevelInfo when that entry is absent. It panics when the configured outputs cannot be
// set up.
func newZapLoggerFactory(properties *Properties) *zapLoggerFactory {
	root := LevelInfo
	if lvl, ok := properties.Levels[keyLevelDefault]; ok {
		root = lvl
	}
	f := &zapLoggerFactory{
		rootLogLevel: root,
		logLevels:    convertLevelsNameToKey(properties.Levels),
		properties:   properties,
		registry:     map[string]*configurableZapLogger{},
		extraValuers: ContextValuers{},
	}
	f.effectiveValuers = f.buildContextValuer(properties)
	creator, e := f.buildZapCoreCreator(properties)
	if e != nil {
		panic(e)
	}
	f.coreCreator = creator
	return f
}
// createLogger returns the cached logger for name's key. Otherwise it creates, caches and returns a new one with
// its own atomic leveler, the effective level for that key, and the current effective valuers.
func (f *zapLoggerFactory) createLogger(name string) ContextualLogger {
	key := loggerKey(name)
	l, cached := f.registry[key]
	if !cached {
		leveler := zap.NewAtomicLevel()
		l = newConfigurableZapLogger(name, f.coreCreator(leveler), f.resolveEffectiveLevel(key), leveler, f.effectiveValuers)
		f.registry[key] = l
	}
	return l
}
// addContextValuers registers additional valuers; later ones win on key conflicts. They are written into the
// effectiveValuers map shared with existing loggers, and remembered in extraValuers so refresh keeps them.
func (f *zapLoggerFactory) addContextValuers(valuers ...ContextValuers) {
	for i := range valuers {
		for logKey, valuer := range valuers[i] {
			f.extraValuers[logKey] = valuer
			f.effectiveValuers[logKey] = valuer
		}
	}
}
// setRootLevel changes the root level and re-applies the effective level to every registered logger.
// It returns the number of loggers updated, i.e. all registered ones.
func (f *zapLoggerFactory) setRootLevel(logLevel LoggingLevel) (affected int) {
	f.rootLogLevel = logLevel
	for key, l := range f.registry {
		l.setMinLevel(f.resolveEffectiveLevel(key))
	}
	return len(f.registry)
}
// setLevel sets (logLevel != nil) or unsets (logLevel == nil) the level configured for the key derived from
// prefix. It then re-applies the effective level to the logger with that key and to every logger nested under it
// (keys starting with key + keySeparator).
//
// Setting a level on the root key (empty, keyLevelDefault, or the key of nameLevelDefault) changes the root level
// for all loggers instead. It returns the number of loggers updated.
func (f *zapLoggerFactory) setLevel(prefix string, logLevel *LoggingLevel) (affected int) {
	key := loggerKey(prefix)
	isRoot := key == "" || key == keyLevelDefault || key == loggerKey(nameLevelDefault)
	if isRoot && logLevel != nil {
		return f.setRootLevel(*logLevel)
	}

	if logLevel == nil {
		// unset; delete is a no-op for absent keys, so no existence check is needed
		delete(f.logLevels, key)
	} else {
		f.logLevels[key] = *logLevel
	}

	// set effective level to all affected loggers
	withDot := key + keySeparator
	for k, l := range f.registry {
		if k != key && !strings.HasPrefix(k, withDot) {
			continue
		}
		l.setMinLevel(f.resolveEffectiveLevel(k))
		affected++
	}
	return affected
}
// refresh re-applies configuration from properties and then updates every registered logger. It refreshes:
//   - the root level and per-key levels;
//   - the context valuers (configured mappings plus any added via addContextValuers);
//   - the output cores.
//
// The output cores are built first. If that fails, the error is returned and the factory is left untouched, so a
// bad configuration cannot leave it half-updated (e.g. with a nil core creator that would panic in createLogger).
func (f *zapLoggerFactory) refresh(properties *Properties) error {
	coreCreator, e := f.buildZapCoreCreator(properties)
	if e != nil {
		return e
	}

	rootLogLevel, ok := properties.Levels[keyLevelDefault]
	if !ok {
		rootLogLevel = LevelInfo
	}
	valuers := buildContextValuerFromConfig(properties)
	// merge valuers, note: we don't delete extra valuers during refresh
	for k, v := range f.extraValuers {
		valuers[k] = v
	}

	// commit the new configuration
	f.rootLogLevel = rootLogLevel
	f.logLevels = convertLevelsNameToKey(properties.Levels)
	f.effectiveValuers = valuers
	f.coreCreator = coreCreator

	for key, l := range f.registry {
		l.core = f.coreCreator(l.leveler)
		l.valuers = f.effectiveValuers
		l.setMinLevel(f.resolveEffectiveLevel(key))
	}
	return nil
}
// resolveEffectiveLevel returns the level configured for key. Failing that, it returns the level of the closest
// parent prefix, dropping one keySeparator-delimited segment at a time (e.g. "a.b.c", then "a.b", then "a").
// The root level applies when nothing matches.
func (f *zapLoggerFactory) resolveEffectiveLevel(key string) LoggingLevel {
	prefix := key
	for prefix != "" {
		if ll, ok := f.logLevels[prefix]; ok {
			return ll
		}
		cut := strings.LastIndex(prefix, keySeparator)
		if cut < 0 {
			break
		}
		prefix = prefix[:cut]
	}
	return f.rootLogLevel
}
// buildContextValuer creates one context valuer per entry of properties.Mappings. Each mapping is
// context-key -> log-key: the valuer is registered under the log key and returns ctx.Value(context-key).
func (f *zapLoggerFactory) buildContextValuer(properties *Properties) ContextValuers {
	valuers := ContextValuers{}
	for ctxKey, logKey := range properties.Mappings {
		// Per-iteration copy: before Go 1.22 every closure created in this loop would share one loop variable,
		// so all valuers would read the last context key.
		ctxKey := ctxKey
		valuers[logKey] = func(ctx context.Context) interface{} {
			return ctx.Value(ctxKey)
		}
	}
	return valuers
}
// buildZapCoreCreator prepares one encoder/write-syncer pair per configured logger output.
// It returns a zapCoreCreator that combines all pairs into a single zapcore.Core for a given level,
// teeing them when there is more than one output.
//
// When properties.Loggers is empty, a single console/text logger is configured.
// Note: in that case properties is mutated.
// The created core is wrapped in internal.ZapTerminalCore only when every output writes to a terminal.
func (f *zapLoggerFactory) buildZapCoreCreator(properties *Properties) (zapCoreCreator, error) {
	if len(properties.Loggers) == 0 {
		properties.Loggers = map[string]*LoggerProperties{
			"default": {
				Type:      TypeConsole,
				Format:    FormatText,
				Template:  defaultTemplate,
				FixedKeys: defaultFixedFields.Values(),
			},
		}
	}
	encoders := make([]zapcore.Encoder, 0, len(properties.Loggers))
	syncers := make([]zapcore.WriteSyncer, 0, len(properties.Loggers))
	for _, loggerProps := range properties.Loggers {
		syncer, e := f.newZapWriteSyncer(loggerProps)
		if e != nil {
			return nil, e
		}
		encoder, e := f.newZapEncoder(loggerProps, syncer.(internal.TerminalAware).IsTerminal())
		if e != nil {
			return nil, e
		}
		syncers = append(syncers, syncer)
		encoders = append(encoders, encoder)
	}

	// The core is terminal-aware only if ALL outputs are terminals.
	// Start from true and clear on the first non-terminal. Starting from false and AND-ing, as before,
	// meant a tee of several terminal outputs was never recognized.
	allTerm := len(syncers) != 0
	for _, s := range syncers {
		if !s.(internal.TerminalAware).IsTerminal() {
			allTerm = false
			break
		}
	}

	return func(level zap.AtomicLevel) zapcore.Core {
		var core zapcore.Core
		switch len(encoders) {
		case 0:
			// not possible: at least one logger output is always configured above
			return zapcore.NewNopCore()
		case 1:
			core = zapcore.NewCore(encoders[0], syncers[0], level)
		default:
			cores := make([]zapcore.Core, len(encoders))
			for i := range encoders {
				cores[i] = zapcore.NewCore(encoders[i], syncers[i], level)
			}
			core = zapcore.NewTee(cores...)
		}
		if allTerm {
			return internal.ZapTerminalCore{Core: core}
		}
		return core
	}, nil
}
// newZapEncoder creates the zapcore.Encoder for one logger output according to props.Format.
//   - FormatJson: zap's JSON encoder; template and fixed keys are not used.
//   - FormatText: a template-based encoder using props.Template. The fixed keys are the default fixed
//     fields plus props.FixedKeys; isTerm selects the colored template functions.
//
// Any other format yields an error.
func (f *zapLoggerFactory) newZapEncoder(props *LoggerProperties, isTerm bool) (zapcore.Encoder, error) {
	if props.Format == FormatText {
		fixed := defaultFixedFields.Copy().Add(props.FixedKeys...)
		textFormatter := internal.NewTemplatedFormatter(props.Template, fixed, isTerm)
		return internal.NewZapFormattedEncoder(zapEncoderConfig, textFormatter, isTerm), nil
	}
	if props.Format == FormatJson {
		return zapcore.NewJSONEncoder(zapEncoderConfig), nil
	}
	return nil, fmt.Errorf("unsupported logger format: %v", props.Format)
}
// newZapWriteSyncer creates the output sink for one logger.
// Console loggers write to os.Stdout; file loggers write to props.Location, which is opened or created.
// All other types (including TypeHttp and TypeMQ) are not supported yet and yield an error.
func (f *zapLoggerFactory) newZapWriteSyncer(props *LoggerProperties) (zapcore.WriteSyncer, error) {
	switch props.Type {
	case TypeConsole:
		return internal.NewZapWriterWrapper(os.Stdout), nil
	case TypeFile:
		out, err := openOrCreateFile(props.Location)
		if err != nil {
			return nil, err
		}
		return internal.NewZapWriterWrapper(out), nil
	}
	return nil, fmt.Errorf("unsupported logger type: %v", props.Type)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package internal
import (
"bytes"
"github.com/cisco-open/go-lanai/pkg/utils"
"go.uber.org/zap/buffer"
"io"
"strings"
"text/template"
)
// Well-known keys of the log fields handed to text templates, plus the name given to the parsed template.
const (
	// LogKeyMessage holds the log message.
	LogKeyMessage = "msg"
	// LogKeyName holds the logger name.
	LogKeyName = "logger"
	// LogKeyTimestamp holds the entry time.
	LogKeyTimestamp = "time"
	// LogKeyCaller holds the caller info.
	LogKeyCaller = "caller"
	// LogKeyLevel holds the log level.
	LogKeyLevel = "level"
	// LogKeyContext holds context-derived values.
	LogKeyContext = "ctx"
	// LogKeyStacktrace holds the stacktrace, when present.
	LogKeyStacktrace = "stacktrace"

	// logTemplate is the name of the text/template created by TemplatedFormatter.
	logTemplate = "lanai-log-template"
)
// Fields holds the key-value pairs of a single log entry; it is the data passed to a TextFormatter.
type Fields map[string]interface{}
// TextFormatter renders log fields as text into a writer.
type TextFormatter interface {
// Format writes the rendered representation of kvs to w.
Format(kvs Fields, w io.Writer) error
}
// TemplatedFormatter is a TextFormatter backed by a text/template.
type TemplatedFormatter struct {
// text is the template source; init appends a trailing newline when missing.
text string
// tmpl is the parsed template, populated by init.
tmpl *template.Template
// fixedFields are the keys excluded from the output of the "kv" template function.
fixedFields utils.StringSet
// isTerm selects the terminal (colored) or non-terminal template function maps.
isTerm bool
}
// NewTemplatedFormatter creates a TemplatedFormatter from a text/template source.
// fixedFields lists keys the "kv" template function skips, and isTerm selects colored output.
// It panics when tmpl cannot be parsed.
func NewTemplatedFormatter(tmpl string, fixedFields utils.StringSet, isTerm bool) *TemplatedFormatter {
	f := &TemplatedFormatter{text: tmpl, fixedFields: fixedFields, isTerm: isTerm}
	f.init()
	return f
}
// init makes sure the template text ends with a newline, then parses it together with the built-in
// template functions:
//   - terminal output: TmplFuncMap and TmplColorFuncMap (colored levels and color functions)
//   - non-terminal output: TmplFuncMapNonTerm and TmplColorFuncMapNonTerm (same names, no color codes)
//   - always: "kv", which renders non-fixed fields
//
// Each function map is bound to the variable matching its name, so that template.Funcs precedence
// (later calls override earlier ones) is the same in both modes.
// It panics when the template text cannot be parsed.
func (f *TemplatedFormatter) init() {
	if !strings.HasSuffix(f.text, "\n") {
		f.text += "\n"
	}
	funcMap, colorFuncMap := TmplFuncMapNonTerm, TmplColorFuncMapNonTerm
	if f.isTerm {
		funcMap, colorFuncMap = TmplFuncMap, TmplColorFuncMap
	}
	t, e := template.New(logTemplate).
		Option("missingkey=zero").
		Funcs(funcMap).
		Funcs(colorFuncMap).
		Funcs(template.FuncMap{
			"kv": MakeKVFunc(f.fixedFields),
		}).
		Parse(f.text)
	if e != nil {
		panic(e)
	}
	f.tmpl = t
}
// Format renders kvs with the parsed template and writes the result to w.
// zap's *buffer.Buffer is rendered into directly. Any other writer receives the complete output in
// a single Write, so concurrent executions sharing a writer do not interleave partial output.
func (f *TemplatedFormatter) Format(kvs Fields, w io.Writer) error {
	if _, isZapBuffer := w.(*buffer.Buffer); isZapBuffer {
		return f.tmpl.Execute(w, kvs)
	}
	// From the docs of template.Template.Execute: a template may be executed safely in parallel, although
	// if parallel executions share a Writer the output may be interleaved. Rendering into memory first
	// avoids that; hopefully this is faster than mutex locking.
	var rendered bytes.Buffer
	if err := f.tmpl.Execute(&rendered, kvs); err != nil {
		return err
	}
	_, err := w.Write(rendered.Bytes())
	return err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package internal
import (
"fmt"
"golang.org/x/term"
"io"
"text/template"
)
// color code generation and terminal check is adopted from github.com/go-kit/log/term

// ColorNames maps color names to Color values; the names can be used with the generic "color" template function.
// A quick foreground-color template function with the same name is also available in templates.
// Names with the "_b" suffix are the bold/bright variants.
var ColorNames = map[string]Color{
	"black_b":   BoldBlack,
	"red_b":     BoldRed,
	"green_b":   BoldGreen,
	"yellow_b":  BoldYellow,
	"blue_b":    BoldBlue,
	"magenta_b": BoldMagenta,
	"cyan_b":    BoldCyan,
	"gray_b":    BoldGray,
	"black":     Black,
	"red":       Red,
	"green":     Green,
	"yellow":    Yellow,
	"blue":      Blue,
	"magenta":   Magenta,
	"cyan":      Cyan,
	"gray":      Gray,
}
// TmplColorFuncMap holds the color template functions used for terminal output; it is populated in init.
var TmplColorFuncMap template.FuncMap

// TmplColorFuncMapNonTerm holds no-op color functions with the same names, for non-terminal output;
// it is populated in init.
var TmplColorFuncMapNonTerm template.FuncMap
// Color identifies a terminal color; its value is used as an index into FgColors and BgColors.
type Color uint8
// Color values. Order matters: values come from iota and must line up with the FgColors/BgColors
// entries generated (in this same order) by init.
const (
// Default applies no color escape sequence.
Default = Color(iota)
Black
Red
Green
Yellow
Blue
Magenta
Cyan
Gray
BoldBlack
BoldRed
BoldGreen
BoldYellow
BoldBlue
BoldMagenta
BoldCyan
BoldGray
// numColors is the count of defined colors; it is not a color itself and bounds valid indexes.
numColors
)
// FgColors and BgColors hold the ANSI escape sequences for each Color, indexed by Color value.
// Both are populated in init.
var FgColors, BgColors []string

// ResetColor restores the default foreground/background and normal intensity.
var ResetColor = "\x1b[39;49;22m"
// init generates the ANSI escape sequences for every Color and builds the color template function maps.
// Implementations adopted from github.com/go-kit/log/term.
func init() {
	// Default: no escape sequence at all, so the terminal's current colors apply
	FgColors = append(FgColors, "")
	BgColors = append(BgColors, "")
	// dark colors use plain SGR codes; bright (bold) colors add ";1"
	for c := Black; c < numColors; c++ {
		offset, intensity := c-Black, ""
		if c >= BoldBlack {
			offset, intensity = c-BoldBlack, ";1"
		}
		FgColors = append(FgColors, fmt.Sprintf("\x1b[%d%sm", 30+offset, intensity))
		BgColors = append(BgColors, fmt.Sprintf("\x1b[%d%sm", 40+offset, intensity))
	}
	// quick color functions: one per color name, plus the generic "color"
	TmplColorFuncMap = template.FuncMap{"color": Colored}
	TmplColorFuncMapNonTerm = template.FuncMap{"color": NoopColored}
	for name, c := range ColorNames {
		TmplColorFuncMap[name] = MakeQuickColorFunc(c)
		TmplColorFuncMapNonTerm[name] = NoopQuickColor
	}
}
// fder is satisfied by writers exposing a file descriptor; it matches os.File.Fd().
type fder interface {
	Fd() uintptr
}

// IsTerminal reports whether w is backed by a file descriptor that refers to a terminal.
// Writers without an Fd method are never considered terminals.
// Implementation adopted from github.com/go-kit/log/term.
func IsTerminal(w io.Writer) bool {
	f, hasFd := w.(fder)
	return hasFd && term.IsTerminal(int(f.Fd()))
}
// ColoredWithCode wraps the string form of s (see Sprint) in the escape sequences for fg and bg,
// followed by ResetColor. Colors outside the known range contribute no escape sequence.
func ColoredWithCode(s interface{}, fg, bg Color) string {
	prefix := ""
	if fg < numColors {
		prefix += FgColors[fg]
	}
	if bg < numColors {
		prefix += BgColors[bg]
	}
	return prefix + Sprint(s) + ResetColor
}
func ColoredWithName(s interface{}, fgName, bgName string) string {
fg, _ := ColorNames[fgName]
bg, _ := ColorNames[bgName]
return ColoredWithCode(s, fg, bg)
}
// Colored is the generic "color" template function. It accepts 0, 1 or 2 color names, in the order
// foreground then background; names beyond the second are ignored.
// With no names, it returns the plain string form of s.
func Colored(s interface{}, colorNames ...string) string {
	if len(colorNames) == 0 {
		return Sprint(s)
	}
	bgName := ""
	if len(colorNames) > 1 {
		bgName = colorNames[1]
	}
	return ColoredWithName(s, colorNames[0], bgName)
}
// MakeQuickColorFunc returns a single-argument template function that paints its input with foreground fg.
func MakeQuickColorFunc(fg Color) func(s interface{}) string {
	paint := func(s interface{}) string {
		return ColoredWithCode(s, fg, Default)
	}
	return paint
}

// NoopColored is the non-terminal counterpart of Colored: it returns v unchanged and ignores the color names.
func NoopColored(v interface{}, _ ...string) interface{} { return v }

// NoopQuickColor is the non-terminal counterpart of quick color functions: it returns v unchanged.
func NoopQuickColor(v interface{}) interface{} { return v }
// DebugShowcase prints every named color to stdout, first as foreground then as background colors,
// four names per line. It is a development aid only.
func DebugShowcase() {
	show := func(paint func(s string, c Color) string) {
		printed := 0
		for name, c := range ColorNames {
			fmt.Printf("%s ", paint(fmt.Sprintf("%-12s", name), c))
			if printed++; printed%4 == 0 {
				fmt.Println()
			}
		}
	}
	show(func(s string, c Color) string { return ColoredWithCode(s, c, Default) })
	show(func(s string, c Color) string { return ColoredWithCode(s, Default, c) })
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package internal
import (
	"encoding"
	"fmt"
	"math"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"text/template"

	"github.com/cisco-open/go-lanai/pkg/utils"
)
// Note: https://pkg.go.dev/text/template#hdr-Pipelines chainable argument should be the last parameter of any function

// TmplFuncMap holds the generic template functions for terminal output ("lvl" renders colored levels).
var TmplFuncMap = template.FuncMap{
	"trace": Trace,
	"join":  Join,
	"lvl":   MakeLevelFunc(true),
	"pad":   Padding,
	"cap":   Capped,
}

// TmplFuncMapNonTerm holds the same functions for non-terminal output ("lvl" renders plain levels).
var TmplFuncMapNonTerm = template.FuncMap{
	"trace": Trace,
	"join":  Join,
	"lvl":   MakeLevelFunc(false),
	"pad":   Padding,
	"cap":   Capped,
}
// levelFuncs renders one known level: text pads its upper-case label, color paints a string in the level's color.
type levelFuncs struct {
	text  func(int) string
	color func(interface{}) string
}

// levelFuncsMap maps the level strings found under LogKeyLevel to their renderers.
var levelFuncsMap = map[string]levelFuncs{
	"error": {text: MakeLevelPaddingFunc("ERROR"), color: MakeQuickColorFunc(BoldRed)},
	"warn":  {text: MakeLevelPaddingFunc("WARN"), color: MakeQuickColorFunc(BoldYellow)},
	"info":  {text: MakeLevelPaddingFunc("INFO"), color: MakeQuickColorFunc(Cyan)},
	"debug": {text: MakeLevelPaddingFunc("DEBUG"), color: MakeQuickColorFunc(Gray)},
}
// MakeKVFunc returns the "kv" template function. It renders every log field as `{k1="v1", k2="v2"}`,
// skipping fields whose key is in ignored and fields whose value is nil or a zero value.
// Keys are emitted in sorted order so that the output is deterministic; Go map iteration order is
// randomized, so without sorting the field order would change from line to line.
// It returns "" when no field is rendered.
func MakeKVFunc(ignored utils.StringSet) func(Fields) string {
	return func(kvs Fields) string {
		keys := make([]string, 0, len(kvs))
		for k, v := range kvs {
			if v == nil || ignored.Has(k) || reflect.ValueOf(v).IsZero() {
				continue
			}
			keys = append(keys, k)
		}
		if len(keys) == 0 {
			return ""
		}
		sort.Strings(keys)
		kvStrs := make([]string, len(keys))
		for i, k := range keys {
			kvStrs[i] = fmt.Sprintf(`%s="%v"`, k, kvs[k])
		}
		return "{" + strings.Join(kvStrs, ", ") + "}"
	}
}
// MakeLevelFunc returns the "lvl" template function. It renders the entry's level (kvs[LogKeyLevel])
// as a label padded to the given width; when term is true, the label is also colored.
// Levels not present in levelFuncsMap are returned as their plain string form.
func MakeLevelFunc(term bool) func(padding int, kvs Fields) string {
	return func(padding int, kvs Fields) string {
		lvStr := Sprint(kvs[LogKeyLevel])
		funcs, known := levelFuncsMap[lvStr]
		if !known {
			return lvStr
		}
		label := funcs.text(padding)
		if !term {
			return label
		}
		return funcs.color(label)
	}
}
// MakeLevelPaddingFunc binds v and returns a function that pads it to a requested width (see Padding).
func MakeLevelPaddingFunc(v interface{}) func(int) string {
	pad := func(width int) string { return Padding(width, v) }
	return pad
}
// Padding formats v with fmt's %v verb at a minimum width of |padding|.
// A positive width right-aligns; a negative width left-aligns.
// In templates it is the "pad" function, e.g. `{{pad -6 .value}}` or `{{.value | pad 10}}`.
func Padding(padding int, v interface{}) string {
	verb := "%" + strconv.Itoa(padding) + "v"
	return fmt.Sprintf(verb, v)
}
// Capped truncate given value to specified length
// if cap > 0: with tailing "..." if truncated
// if cap < 0: with middle "..." if truncated
func Capped(cap int, v interface{}) string {
c := int(math.Abs(float64(cap)))
s := Sprint(v)
if len(s) <= c {
return s
}
if cap > 0 {
return fmt.Sprintf("%." + strconv.Itoa(c - 3) + "s...", s)
} else if cap < 0 {
lead := (c - 3) / 2
tail := c - lead - 3
return fmt.Sprintf("%." + strconv.Itoa(lead) + "s...%s", s, s[len(s)-tail:])
} else {
return ""
}
}
// Join joins the string forms of values (see Sprint) with sep, skipping values that render as "".
func Join(sep string, values ...interface{}) string {
	parts := make([]string, 0, len(values))
	for _, v := range values {
		if s := Sprint(v); len(s) > 0 {
			parts = append(parts, s)
		}
	}
	return strings.Join(parts, sep)
}
// Trace generates the shortest possible tracing info string:
//   - if the trace ID is not available, it returns an empty string
//   - if the span ID equals the trace ID, the parent ID is assumed to be 0 and only the trace ID is returned
//   - if the parent ID equals the trace ID, only the trace ID and span ID are returned
//   - otherwise it returns "traceID spanID parentID"
func Trace(tid, sid, pid interface{}) string {
	traceID, spanID, parentID := Sprint(tid), Sprint(sid), Sprint(pid)
	if traceID == "" {
		return ""
	}
	if spanID == traceID {
		return traceID
	}
	if parentID == traceID {
		return traceID + " " + spanID
	}
	return traceID + " " + spanID + " " + parentID
}
// Sprint converts v to a string for template output.
// nil becomes "", strings and byte slices are used as-is, fmt.Stringer and encoding.TextMarshaler
// are honored (a failing MarshalText falls back), and everything else is formatted with %v.
func Sprint(v interface{}) string {
	switch val := v.(type) {
	case nil:
		return ""
	case string:
		return val
	case []byte:
		return string(val)
	case fmt.Stringer:
		return val.String()
	case encoding.TextMarshaler:
		if text, err := val.MarshalText(); err == nil {
			return string(text)
		}
	}
	return fmt.Sprintf("%v", v)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package internal
import (
"go.uber.org/zap/buffer"
"go.uber.org/zap/zapcore"
"time"
)
// zapBufferPool supplies the buffers returned by ZapFormattedEncoder.EncodeEntry.
var zapBufferPool = buffer.NewPool()
// ZapFormattedEncoder implements zapcore.Encoder. This encoder leverages the go template system to render user defined logs.
// Note: Unlike zapcore's JSONEncoder and ConsoleEncoder, this encoder focuses on flexibility rather than performance.
// When performance is crucial, JSON format of log should be used.
type ZapFormattedEncoder struct {
// MapObjectEncoder provides the zapcore.ObjectEncoder methods required by zapcore.Encoder.
*zapcore.MapObjectEncoder
// Formatter renders the collected fields of each entry.
Formatter TextFormatter
// Config supplies the entry keys (NameKey, LevelKey, ...) and value encoders (EncodeName, EncodeLevel, ...).
Config *zapcore.EncoderConfig
// IsTerminal records whether output is meant for a terminal.
IsTerminal bool
}
// NewZapFormattedEncoder creates a ZapFormattedEncoder that renders entries with formatter.
// cfg is copied; the encoder keeps a pointer to its own copy.
func NewZapFormattedEncoder(cfg zapcore.EncoderConfig, formatter TextFormatter, isTerm bool) zapcore.Encoder {
	enc := &ZapFormattedEncoder{
		Formatter:  formatter,
		IsTerminal: isTerm,
	}
	enc.MapObjectEncoder = zapcore.NewMapObjectEncoder()
	enc.Config = &cfg
	return enc
}
// Clone implements zapcore.Encoder. The copy shares Formatter and Config but owns its field map.
// That map is pre-populated with a shallow copy of the fields accumulated on enc so far; previously
// these were dropped. Fields added to the copy do not affect enc.
// NOTE(review): nested namespace maps (OpenNamespace) are shared by reference, and the copy starts
// writing at the top-level namespace.
func (enc *ZapFormattedEncoder) Clone() zapcore.Encoder {
	objEnc := zapcore.NewMapObjectEncoder()
	if enc.MapObjectEncoder != nil {
		for k, v := range enc.MapObjectEncoder.Fields {
			objEnc.Fields[k] = v
		}
	}
	return &ZapFormattedEncoder{
		MapObjectEncoder: objEnc,
		Formatter:        enc.Formatter,
		Config:           enc.Config,
		IsTerminal:       enc.IsTerminal,
	}
}
// EncodeEntry implements zapcore.Encoder.
// It collects the following into one map, which Formatter then renders into a pooled buffer:
//   - the entry: name, level, time, message, caller, stacktrace
//   - the fields accumulated on this encoder (as zapcore's core.With adds to a cloned encoder)
//   - the per-call fields
//
// Later sources override earlier ones on key collision.
// Map and slice based encoders are necessary for the go template based formatter. This is not the
// most performant approach, but it wouldn't be the bottleneck compared to go template rendering.
// On a formatter error, the buffer is returned to the pool and the error is propagated.
func (enc *ZapFormattedEncoder) EncodeEntry(entry zapcore.Entry, fields []zapcore.Field) (*buffer.Buffer, error) {
	objEnc := zapcore.NewMapObjectEncoder()
	// encode entry; Config's value encoders (EncodeName, EncodeLevel, ...) are applied through arrEnc
	arrEnc := &SliceArrayEncoder{make([]interface{}, 0, 5)}
	objEnc.Fields[enc.Config.NameKey] = applyZapValueEncoder(arrEnc, entry.LoggerName, enc.Config.EncodeName)
	objEnc.Fields[enc.Config.LevelKey] = applyZapValueEncoder(arrEnc, entry.Level, enc.Config.EncodeLevel)
	objEnc.Fields[enc.Config.TimeKey] = applyZapValueEncoder(arrEnc, entry.Time, enc.Config.EncodeTime)
	objEnc.Fields[enc.Config.MessageKey] = entry.Message
	if entry.Caller.Defined {
		objEnc.Fields[enc.Config.CallerKey] = applyZapValueEncoder(arrEnc, entry.Caller, enc.Config.EncodeCaller)
	}
	if len(entry.Stack) != 0 {
		objEnc.Fields[enc.Config.StacktraceKey] = entry.Stack
	}
	// context fields previously added to this encoder; ignoring them silently lost With(...) fields
	if enc.MapObjectEncoder != nil {
		for k, v := range enc.MapObjectEncoder.Fields {
			objEnc.Fields[k] = v
		}
	}
	// per-call fields
	for i := range fields {
		fields[i].AddTo(objEnc)
	}
	buf := zapBufferPool.Get()
	if e := enc.Formatter.Format(objEnc.Fields, buf); e != nil {
		// the buffer is not handed to the caller on error, so return it to the pool here
		buf.Free()
		return nil, e
	}
	return buf, nil
}
// applyZapValueEncoder runs a zapcore value encoder (e.g. EncoderConfig.EncodeLevel) on value, using
// arrEnc as the sink, and returns the element the encoder appended.
// The raw value is returned instead in two cases: when valueEncoder is nil, and when the encoder
// appends nothing (e.g. a user-supplied no-op encoder). This is similar to zap's JSON encoder, which
// falls back to a default representation when an encoder writes nothing. Without this check,
// arrEnc.Latest() would return a stale element from a previous call, or panic on an empty slice.
func applyZapValueEncoder[T any](arrEnc *SliceArrayEncoder, value T, valueEncoder func(T, zapcore.PrimitiveArrayEncoder)) interface{} {
	if valueEncoder == nil {
		return value
	}
	before := len(arrEnc.elems)
	valueEncoder(value, arrEnc)
	if len(arrEnc.elems) == before {
		return value
	}
	return arrEnc.Latest()
}
// SliceArrayEncoder implements zapcore.PrimitiveArrayEncoder by collecting appended values into a slice.
// It is used to apply zapcore's entry encoders (zapcore.NameEncoder, LevelEncoder, ...) and read their output back.
// Byte strings are stored as string; every other value is stored with its original type.
type SliceArrayEncoder struct {
	elems []interface{}
}

// Latest returns the most recently appended element. It panics if nothing has been appended.
func (a *SliceArrayEncoder) Latest() interface{} {
	last := len(a.elems) - 1
	return a.elems[last]
}

func (a *SliceArrayEncoder) AppendBool(v bool)              { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendByteString(v []byte)      { a.elems = append(a.elems, string(v)) }
func (a *SliceArrayEncoder) AppendComplex128(v complex128)  { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendComplex64(v complex64)    { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendDuration(v time.Duration) { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendFloat64(v float64)        { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendFloat32(v float32)        { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendInt(v int)                { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendInt64(v int64)            { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendInt32(v int32)            { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendInt16(v int16)            { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendInt8(v int8)              { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendString(v string)          { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendTime(v time.Time)         { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendUint(v uint)              { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendUint64(v uint64)          { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendUint32(v uint32)          { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendUint16(v uint16)          { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendUint8(v uint8)            { a.elems = append(a.elems, v) }
func (a *SliceArrayEncoder) AppendUintptr(v uintptr)        { a.elems = append(a.elems, v) }
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package internal
import (
"go.uber.org/zap/zapcore"
"io"
)
// TerminalAware is implemented by writers and cores that can report whether their output goes to a terminal.
type TerminalAware interface {
// IsTerminal reports whether output goes to a terminal.
IsTerminal() bool
}
// ZapWriterWrapper adapts an io.Writer to zapcore.WriteSyncer and TerminalAware.
type ZapWriterWrapper struct {
	io.Writer
}

// Sync implements zapcore.WriteSyncer. It is a no-op that always returns nil, even when the wrapped
// writer has its own Sync method.
func (ZapWriterWrapper) Sync() error { return nil }

// IsTerminal implements TerminalAware by checking whether the wrapped writer is a terminal.
func (w ZapWriterWrapper) IsTerminal() bool { return IsTerminal(w.Writer) }

// NewZapWriterWrapper wraps w, similar to zapcore.AddSync but returning an exported type.
// Unlike zapcore.AddSync, it never delegates Sync to w.
func NewZapWriterWrapper(w io.Writer) zapcore.WriteSyncer {
	wrapped := ZapWriterWrapper{Writer: w}
	return wrapped
}
// ZapTerminalCore wraps a zapcore.Core and marks it as writing to a terminal.
// It implements TerminalAware and always reports true.
type ZapTerminalCore struct {
	zapcore.Core
}

// IsTerminal implements TerminalAware; it always returns true.
func (ZapTerminalCore) IsTerminal() bool { return true }
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package log
import (
"fmt"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"runtime"
)
/************************
	level enabler
************************/

// zapLevel adapts LoggingLevel to zapcore.LevelEnabler.
type zapLevel LoggingLevel

// Enabled reports whether zapLvl is at or above the zap level matching this LoggingLevel.
// LevelOff disables everything. Unrecognized levels behave like LevelInfo (threshold zapcore.InfoLevel, i.e. 0).
func (lvl zapLevel) Enabled(zapLvl zapcore.Level) bool {
	var threshold zapcore.Level
	switch LoggingLevel(lvl) {
	case LevelDebug:
		threshold = zapcore.DebugLevel
	case LevelInfo:
		threshold = zapcore.InfoLevel
	case LevelWarn:
		threshold = zapcore.WarnLevel
	case LevelError:
		threshold = zapcore.ErrorLevel
	case LevelOff:
		return false
	default:
		threshold = 0
	}
	return zapLvl >= threshold
}
/************************
logger
************************/
// zapLogger implements Logger on top of a zapcore.Core.
// Derived loggers (WithKV, WithLevel, WithCaller) are shallow copies of their parent.
type zapLogger struct {
// core receives the entries; level filtering and output happen there.
core zapcore.Core
// name is reported as each entry's LoggerName.
name string
// clock provides entry timestamps.
clock zapcore.Clock
// stacktracer, when set, supplies caller information for each entry (see WithCaller).
stacktracer Stacktracer
// fields are attached to every entry written by this logger.
fields []zapcore.Field
//zap.SugaredLogger
}
// WithKV returns a derived Logger that attaches the given key-value pairs to every entry it writes.
func (l *zapLogger) WithKV(keyvals ...interface{}) Logger {
	derived := l.withKV(keyvals)
	return derived
}
// WithLevel returns a derived Logger whose minimum level is raised to lvl.
// zap cannot lower the level of an existing core; when that is attempted, the derived logger discards everything.
func (l *zapLogger) WithLevel(lvl LoggingLevel) Logger {
	derived := l.shallowCopy()
	leveled, err := zapcore.NewIncreaseLevelCore(l.core, zapLevel(lvl))
	if err != nil {
		leveled = zapcore.NewNopCore()
	}
	derived.core = leveled
	return derived
}
// WithCaller implements CallerValuer. It returns a derived Logger whose caller information comes from caller:
//   - nil: no caller information
//   - a Stacktracer, or a func with the same signature: used as-is
//   - func() interface{}: evaluated per entry and logged as a fallback caller value
//   - anything else: logged as a constant fallback caller value
func (l *zapLogger) WithCaller(caller interface{}) Logger {
	var tracer Stacktracer
	switch v := caller.(type) {
	case nil:
		tracer = nil
	case func() ([]*runtime.Frame, interface{}):
		tracer = v
	case Stacktracer:
		tracer = v
	case func() interface{}:
		tracer = func() ([]*runtime.Frame, interface{}) { return nil, v() }
	default:
		tracer = func() ([]*runtime.Frame, interface{}) { return nil, caller }
	}
	derived := l.shallowCopy()
	derived.stacktracer = tracer
	return derived
}
// Debugf logs a printf-style message at debug level.
func (l *zapLogger) Debugf(msg string, args ...interface{}) { l.log(zapcore.DebugLevel, msg, args, nil) }

// Infof logs a printf-style message at info level.
func (l *zapLogger) Infof(msg string, args ...interface{}) { l.log(zapcore.InfoLevel, msg, args, nil) }

// Warnf logs a printf-style message at warn level.
func (l *zapLogger) Warnf(msg string, args ...interface{}) { l.log(zapcore.WarnLevel, msg, args, nil) }

// Errorf logs a printf-style message at error level.
func (l *zapLogger) Errorf(msg string, args ...interface{}) { l.log(zapcore.ErrorLevel, msg, args, nil) }

// Debug logs msg with key-value pairs at debug level.
func (l *zapLogger) Debug(msg string, keyvals ...interface{}) { l.log(zapcore.DebugLevel, msg, nil, keyvals) }

// Info logs msg with key-value pairs at info level.
func (l *zapLogger) Info(msg string, keyvals ...interface{}) { l.log(zapcore.InfoLevel, msg, nil, keyvals) }

// Warn logs msg with key-value pairs at warn level.
func (l *zapLogger) Warn(msg string, keyvals ...interface{}) { l.log(zapcore.WarnLevel, msg, nil, keyvals) }

// Error logs msg with key-value pairs at error level.
func (l *zapLogger) Error(msg string, keyvals ...interface{}) { l.log(zapcore.ErrorLevel, msg, nil, keyvals) }
// Print logs args, formatted like fmt.Sprint, at zapcore.LevelOf(core) (the core's lowest enabled level).
func (l *zapLogger) Print(args ...interface{}) {
	l.log(zapcore.LevelOf(l.core), "", args, nil)
}

// Printf logs args formatted with format, at zapcore.LevelOf(core).
func (l *zapLogger) Printf(format string, args ...interface{}) {
	l.log(zapcore.LevelOf(l.core), format, args, nil)
}

// Println logs args, formatted like fmt.Sprintln, at zapcore.LevelOf(core).
func (l *zapLogger) Println(args ...interface{}) {
	l.log(zapcore.LevelOf(l.core), "\n", args, nil)
}

// Log logs key-value pairs with an empty message at zapcore.LevelOf(core). It always returns nil.
func (l *zapLogger) Log(keyvals ...interface{}) error {
	l.log(zapcore.LevelOf(l.core), "", nil, keyvals)
	return nil
}
// shallowCopy returns a new zapLogger with the same field values as l; slices and the core are shared.
func (l *zapLogger) shallowCopy() *zapLogger {
	dup := new(zapLogger)
	*dup = *l
	return dup
}
// withKV returns a shallow copy of l whose fields are l.fields followed by the fields converted from keyvals.
// The parent's slice is capped at its length (full slice expression) before appending, so the copy
// always gets its own backing array when fields are added. Plain append could write into the parent's
// spare capacity, making sibling loggers derived from the same parent overwrite each other's fields.
func (l *zapLogger) withKV(keyvals []interface{}) Logger {
	cpy := l.shallowCopy()
	n := len(l.fields)
	cpy.fields = append(l.fields[:n:n], l.toFields(keyvals)...)
	return cpy
}
// log is the common implementation behind all logging methods.
//   - lvl: zap level of the entry
//   - msgTmpl/fmtArgs: combined into the message by constructMessage
//   - keyvals: ad-hoc key-value pairs, converted by toFields and written after the logger's own fields
//
// When l.stacktracer is set, its first frame becomes the entry's caller; otherwise its non-nil fallback
// value is logged under LogKeyCaller.
func (l *zapLogger) log(lvl zapcore.Level, msgTmpl string, fmtArgs []interface{}, keyvals []interface{}) {
	// If logging at this level is completely disabled, skip the overhead of string formatting.
	if lvl < zapcore.DPanicLevel && !l.core.Enabled(lvl) {
		return
	}
	msg := l.constructMessage(msgTmpl, fmtArgs)
	ce := l.core.Check(zapcore.Entry{
		LoggerName: l.name,
		Time:       l.clock.Now(),
		Level:      lvl,
		Message:    msg,
	}, nil)
	if ce == nil {
		return
	}
	// process fields
	adhocFields := l.toFields(keyvals)
	// caller and stacktrace
	if l.stacktracer != nil {
		switch frames, fallback := l.stacktracer(); {
		case len(frames) != 0:
			ce.Caller = zapcore.EntryCaller{
				Defined:  frames[0].PC != 0,
				PC:       frames[0].PC,
				File:     frames[0].File,
				Line:     frames[0].Line,
				Function: frames[0].Function,
			}
			//Note: we currently don't support stacktrace. Here is the place to add it
		case fallback != nil:
			adhocFields = append(adhocFields, zap.Any(LogKeyCaller, fallback))
		default:
			// no caller info
		}
	}
	// write log. l.fields is shared by concurrent calls on this logger and by derived loggers. Capping it
	// at its length makes append copy into a fresh array, instead of racing on shared spare capacity.
	n := len(l.fields)
	ce.Write(append(l.fields[:n:n], adhocFields...)...)
}
// constructMessage builds the log message, similar to zap SugaredLogger's getMessage(...):
//   - no args: tmpl as-is
//   - tmpl "\n": fmt.Sprintln(args...)
//   - empty tmpl: fmt.Sprint(args...)
//   - otherwise: fmt.Sprintf(tmpl, args...)
func (l *zapLogger) constructMessage(tmpl string, fmtArgs []interface{}) string {
	if len(fmtArgs) == 0 {
		return tmpl
	}
	if tmpl == "\n" {
		return fmt.Sprintln(fmtArgs...)
	}
	if tmpl == "" {
		return fmt.Sprint(fmtArgs...)
	}
	return fmt.Sprintf(tmpl, fmtArgs...)
}
// toFields converts alternating key/value arguments into zap fields.
// Keys are converted as follows: strings are used as-is, fmt.Stringer via String(), anything else via fmt.Sprint.
// Values are handled as follows:
//   - zapcore.Field / *zapcore.Field: used directly, ignoring the paired key
//   - func() interface{}: evaluated now
//   - anything else: wrapped with zap.Any
//
// A trailing key without a value gets the placeholder "!(MISSING)". It returns nil for no input.
func (l *zapLogger) toFields(keyvals []interface{}) []zapcore.Field {
	count := len(keyvals)
	if count == 0 {
		return nil
	}
	// give it enough space to avoid re-allocation
	fields := make([]zapcore.Field, 0, count/2+2)
	for i := 0; i < count; i += 2 {
		var key string
		switch k := keyvals[i].(type) {
		case string:
			key = k
		case fmt.Stringer:
			key = k.String()
		default:
			key = fmt.Sprint(k)
		}
		if i+1 >= count {
			fields = append(fields, zap.String(key, "!(MISSING)"))
			break
		}
		switch val := keyvals[i+1].(type) {
		case zapcore.Field:
			fields = append(fields, val)
		case *zapcore.Field:
			fields = append(fields, *val)
		case func() interface{}:
			fields = append(fields, zap.Any(key, val()))
		default:
			fields = append(fields, zap.Any(key, val))
		}
	}
	return fields
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package log
import (
"dario.cat/mergo"
"strings"
)
// LevelConfig is a read-only carrier struct that stores LoggingLevel configuration of each logger
type LevelConfig struct {
// Name is the logger name (or logger name prefix) this configuration refers to.
Name string
// EffectiveLevel is the level currently in effect for Name.
EffectiveLevel *LoggingLevel
// ConfiguredLevel is the level explicitly configured for Name; nil when none is configured.
ConfiguredLevel *LoggingLevel
}
// SetLevel sets (logLevel != nil) or unsets (logLevel == nil) the logging level of all loggers with the given prefix.
// It returns the actual number of affected loggers.
func SetLevel(prefix string, logLevel *LoggingLevel) int {
	affected := factory.setLevel(prefix, logLevel)
	return affected
}
// Levels returns the logger level configuration for every registered logger whose key starts with the
// key of prefix (a plain string-prefix match). Every parent name of those loggers (split on
// keySeparator) is included as well.
// The returned map's key is the logger key of each name (loggerKey; presumably the lower-cased name).
// When prefix is "", the root level is included under keyLevelDefault.
//
// The LevelConfig values point to copies of the levels. Previously they pointed into live logger state
// and the factory's root level, so callers could race with, or silently mutate, the active
// configuration. Use SetLevel to change levels.
func Levels(prefix string) (ret map[string]*LevelConfig) {
	ret = map[string]*LevelConfig{}
	prefixKey := loggerKey(prefix)
	// collect level config names
	for k, v := range factory.registry {
		if !strings.HasPrefix(k, prefixKey) {
			continue
		}
		var p string
		for i := len(v.name); i > 0; i = strings.LastIndex(p, keySeparator) {
			p = v.name[0:i]
			ret[loggerKey(p)] = &LevelConfig{Name: p}
		}
	}
	// populate result with copies of the current levels
	for k, v := range ret {
		var effective LoggingLevel
		if l, ok := factory.registry[k]; ok {
			effective = l.lvl
		} else {
			effective = factory.resolveEffectiveLevel(k)
		}
		v.EffectiveLevel = &effective
		if ll, ok := factory.logLevels[k]; ok {
			configured := ll
			v.ConfiguredLevel = &configured
		}
	}
	if prefix == "" {
		effective, configured := factory.rootLogLevel, factory.rootLogLevel
		ret[keyLevelDefault] = &LevelConfig{
			Name:            nameLevelDefault,
			EffectiveLevel:  &effective,
			ConfiguredLevel: &configured,
		}
	}
	return ret
}
// UpdateLoggingConfiguration merges properties over the default configuration into a fresh Properties,
// with later values overwriting earlier ones, and refreshes all loggers with the result.
// Merge and refresh errors are returned as-is.
func UpdateLoggingConfiguration(properties *Properties) error {
	overwrite := func(cfg *mergo.Config) { cfg.Overwrite = true }
	merged := &Properties{}
	for _, src := range []*Properties{defaultConfig, properties} {
		if err := mergo.Merge(merged, src, overwrite); err != nil {
			return err
		}
	}
	return factory.refresh(merged)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package log
import (
"embed"
"encoding/json"
"github.com/ghodss/yaml"
"io"
"io/fs"
"os"
"path"
"reflect"
"strings"
)
// factory is created by init, and used to create new loggers.
var factory *zapLoggerFactory

// defaultConfig holds the logging Properties resolved by init; UpdateLoggingConfiguration merges over it.
var defaultConfig *Properties
// New is the intuitive starting point for any package that uses the log package.
// It creates a named logger, delegating to the factory, if a logger with this name doesn't exist yet.
func New(name string) ContextualLogger {
	logger := factory.createLogger(name)
	return logger
}
// RegisterContextLogFields registers additional ContextValuers with the logger factory.
// Valuers registered here are kept when the logging configuration is refreshed.
func RegisterContextLogFields(extractors ...ContextValuers) {
	f := factory
	f.addContextValuers(extractors...)
}
// defaultConfigFS embeds the built-in logging defaults ("defaults-log.yml"), used by init when
// configs/log.yml is missing or cannot be loaded. The go:embed directive must stay directly above the var.
//go:embed defaults-log.yml
var defaultConfigFS embed.FS
// Since the log package cannot depend on other packages (those packages may want to use log),
// the code for reading the config profile is duplicated here.
// Configuration is resolved in order of preference:
//  1. configs/log.yml, relative to the working directory, when it exists and loads
//  2. the embedded defaults-log.yml
//  3. built-in newProperties() defaults
func init() {
	cfgPath := path.Join("configs", "log.yml")
	var cfg *Properties
	var loadErr error
	if info, statErr := os.Stat(cfgPath); statErr == nil && !info.IsDir() {
		cfg, loadErr = loadConfig(os.DirFS("."), cfgPath)
	}
	if loadErr != nil || cfg == nil {
		cfg, loadErr = loadConfig(defaultConfigFS, "defaults-log.yml")
	}
	if loadErr != nil || cfg == nil {
		cfg = newProperties()
	}
	defaultConfig = cfg
	factory = newZapLoggerFactory(defaultConfig)
	// a test run for dev: DebugShowcase()
}
// loadConfig reads the YAML file name from fsys and decodes it into Properties (defaults from
// newProperties), then lower-cases all map keys via normalizeProperties.
// The file is always closed; previously it was left open. The parameters were renamed so they no
// longer shadow the fs and path packages.
func loadConfig(fsys fs.FS, name string) (*Properties, error) {
	file, e := fsys.Open(name)
	if e != nil {
		return nil, e
	}
	// read-only file: a Close error carries no information worth failing the load for
	defer func() { _ = file.Close() }()
	encoded, e := io.ReadAll(file)
	if e != nil {
		return nil, e
	}
	encodedJson, e := yaml.YAMLToJSON(encoded)
	if e != nil {
		return nil, e
	}
	props := newProperties()
	if e := json.Unmarshal(encodedJson, props); e != nil {
		return nil, e
	}
	normalizeProperties(props)
	return props, nil
}
// normalizeProperties lower-cases the keys of every string-keyed map field of props, in place and recursively,
// consistent with appconfig binding. Non-map fields are left as they are.
func normalizeProperties(props *Properties) {
	v := reflect.ValueOf(props).Elem()
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		field.Set(normalizeMapKeys(field))
	}
}
// normalizeMapKeys returns a new map of the same type with every key lower-cased, consistent with appconfig binding.
// Map values whose type is itself a string-keyed map are normalized recursively.
// Values that are not maps, and maps whose key kind is not string, are returned unchanged.
// Lower-cased keys are converted back to the map's declared key type. This supports maps keyed by named
// string types (e.g. `map[K]V` with `type K string`), for which SetMapIndex with a plain string panicked.
// NOTE: keys that differ only by case collide; which value wins is unspecified (map iteration order).
func normalizeMapKeys(mapValue reflect.Value) reflect.Value {
	typ := mapValue.Type()
	if typ.Kind() != reflect.Map || typ.Key().Kind() != reflect.String {
		return mapValue
	}
	keyType := typ.Key()
	ret := reflect.MakeMapWithSize(typ, mapValue.Len())
	for iter := mapValue.MapRange(); iter.Next(); {
		lowered := reflect.ValueOf(strings.ToLower(iter.Key().String())).Convert(keyType)
		ret.SetMapIndex(lowered, normalizeMapKeys(iter.Value()))
	}
	return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package log
import "github.com/cisco-open/go-lanai/pkg/utils"
// defaultTemplate is the text/template used by the built-in "console" logger that newProperties sets up.
const defaultTemplate = `{{pad -25 .time}} {{lvl 5 .}} [{{cap -20 .caller | pad 20 | blue}}] {{cap -12 .logger | pad 12 | green}}: {{.msg}} {{kv .}}`

// defaultFixedFields is the set of standard log keys. It presumably identifies keys rendered by the template itself
// rather than as extra KVs; confirm against its users.
var defaultFixedFields = utils.NewStringSet(
	LogKeyMessage, LogKeyName, LogKeyTimestamp,
	LogKeyCaller, LogKeyLevel, LogKeyContext,
)
// Properties contains logging settings, bound from the "levels", "loggers" and "context-mappings" keys.
// Note:
// 1. "levels" maps a key (presumably a logger name) to a LoggingLevel; newProperties seeds the "default" key with
//    LevelInfo (presumably the level used when no more specific entry matches; confirm in the level resolution code)
// 2. "loggers" maps an output name (e.g. "console") to its LoggerProperties
// 3. "context-mappings" indicates how to map a context key to a log key, i.e. map[context-key]log-key
// 4. when loaded from YAML (loadConfig), the keys of these maps are lower-cased by normalizeProperties, which is
//    consistent with appconfig binding
type Properties struct {
Levels map[string]LoggingLevel `json:"levels"`
Loggers map[string]*LoggerProperties `json:"loggers"`
Mappings map[string]string `json:"context-mappings"`
}
// LoggerProperties configures an individual logger output.
// Note:
// 1. we currently only support file and console type
// 2. "location" is ignored when "type" is "console"
// 3. "template" and "fixed-keys" are ignored when "format" is not "text"
// 4. "template" is a "text/template" compliant template, evaluated with "." as the log KVs, plus added functions.
//    The forms below are the ones used by defaultTemplate (the "console" default set up in newProperties):
//    - "{{pad -25 .time}}" fixed-width padding (see Padding)
//    - "{{cap -20 .caller}}" truncation with "..." (see Capped)
//    - "{{lvl 5 .}}" colored level string with fixed length
//    - "{{kv .}}" presumably renders the remaining KVs; confirm in the template function map
//    - color functions with pipeline support, e.g. "{{cap -20 .caller | pad 20 | blue}}"; documented colors
//      include red, green, yellow, gray and cyan, and the default template also uses blue
type LoggerProperties struct {
Type LoggerType `json:"type"`
Format Format `json:"format"`
Location string `json:"location"`
Template string `json:"template"`
FixedKeys utils.CommaSeparatedSlice `json:"fixed-keys"`
}
// newProperties returns the built-in logging defaults:
//   - level "default" set to LevelInfo
//   - a single text-format "console" logger using defaultTemplate and the standard fixed keys
//   - no context mappings (an empty, non-nil map)
func newProperties() *Properties {
	consoleLogger := &LoggerProperties{
		Type:     TypeConsole,
		Format:   FormatText,
		Template: defaultTemplate,
		FixedKeys: utils.CommaSeparatedSlice{
			LogKeyName, LogKeyMessage, LogKeyTimestamp,
			LogKeyCaller, LogKeyLevel, LogKeyContext,
		},
	}
	levels := map[string]LoggingLevel{"default": LevelInfo}
	loggers := map[string]*LoggerProperties{"console": consoleLogger}
	return &Properties{
		Levels:   levels,
		Loggers:  loggers,
		Mappings: make(map[string]string),
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package log
import "runtime"
// Stacktracer captures stack frames when invoked, returning the frames and a fallback value.
// NOTE(review): the meaning of fallback is not established here. RuntimeStacktracer always returns nil for it;
// presumably other implementations return an alternative representation when frames are unavailable.
type Stacktracer func() (frames []*runtime.Frame, fallback interface{})
// RuntimeStacktracer returns a Stacktracer that captures stack frames with the runtime package.
//
// skip: number of stack frames to skip from the top, counted as runtime.Callers counts them from inside the
// returned Stacktracer (0 is runtime.Callers itself, 1 is the Stacktracer itself).
// depth: maximum number of program counters to collect. A non-positive depth yields no frames instead of panicking.
//
// The returned frames slice is never nil, never contains nil entries, and holds at most depth frames. The fallback
// value is always nil.
func RuntimeStacktracer(skip int, depth int) Stacktracer {
	return func() (frames []*runtime.Frame, fallback interface{}) {
		if depth <= 0 {
			return make([]*runtime.Frame, 0), nil
		}
		rpc := make([]uintptr, depth)
		count := runtime.Callers(skip, rpc)
		// Only the first count entries are filled in by runtime.Callers.
		rawFrames := runtime.CallersFrames(rpc[:count])
		frames = make([]*runtime.Frame, 0, count)
		for len(frames) < count {
			frame, more := rawFrames.Next()
			// Next returns a zero Frame once the PCs are exhausted; never report it as a real frame.
			if frame.PC == 0 {
				break
			}
			frames = append(frames, &frame)
			if !more {
				break
			}
		}
		return frames, nil
	}
}
// RuntimeCaller is shorthand for RuntimeStacktracer(skip, 1). It collects a single program counter after skipping
// skip frames, which normally yields just the caller's frame.
func RuntimeCaller(skip int) Stacktracer {
	const callerDepth = 1
	return RuntimeStacktracer(skip, callerDepth)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package log
import (
"github.com/cisco-open/go-lanai/pkg/log/internal"
)
// IsTerminal reports whether l says it writes to a terminal, as reported by internal.TerminalAware.
// Loggers that do not implement internal.TerminalAware are treated as non-terminal.
func IsTerminal(l Logger) bool {
	if aware, ok := l.(internal.TerminalAware); ok {
		return aware.IsTerminal()
	}
	return false
}
// Capped truncates the given value to the specified length, returning the result as a string:
//   - length > 0: with trailing "..." if truncated
//   - length < 0: with middle "..." if truncated
//
// The parameter is named length rather than cap to avoid shadowing the builtin cap.
func Capped(v interface{}, length int) string {
	return internal.Capped(length, v)
}
// Padding pads v to the fixed width given by padding and returns the result as a string. It delegates to
// internal.Padding.
// Example: `Padding("some string", -20)`.
// NOTE(review): the sign of padding presumably selects the alignment side; confirm in internal.Padding.
func Padding(v interface{}, padding int) string {
	width := padding
	return internal.Padding(width, v)
}
// DebugShowcase delegates to internal.DebugShowcase. It is a development aid that presumably emits sample log output
// for a visual check; confirm in the internal package.
func DebugShowcase() {
internal.DebugShowcase()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package migration
import (
"github.com/cisco-open/go-lanai/pkg/data"
"go.uber.org/fx"
"gorm.io/gorm"
)
// gormConfigurer is a stateless data.GormConfigurer that tunes *gorm.Config for migration runs.
type gormConfigurer struct{}

// DefaultGormConfigurerProvider returns an fx annotation that contributes the migration GORM configurer to
// data.GormConfigurerGroup.
func DefaultGormConfigurerProvider() fx.Annotated {
	annotated := fx.Annotated{Target: newGormMigrationConfigurer}
	annotated.Group = data.GormConfigurerGroup
	return annotated
}

// newGormMigrationConfigurer constructs the migration configurer.
func newGormMigrationConfigurer() data.GormConfigurer {
	configurer := &gormConfigurer{}
	return configurer
}

// Order returns 0, this configurer's position among the data.GormConfigurerGroup members.
func (gormConfigurer) Order() int {
	return 0
}

// Configure sets these gorm.Config options:
//   - batch creation of 1000 rows
//   - no implicit default transaction
//   - no full association saving
//   - no foreign-key constraints created during migration
func (gormConfigurer) Configure(cfg *gorm.Config) {
	cfg.CreateBatchSize = 1000
	cfg.SkipDefaultTransaction = true
	cfg.FullSaveAssociations = false
	cfg.DisableForeignKeyConstraintWhenMigrating = true
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package migration
import (
"context"
"gorm.io/gorm"
"time"
)
// MigrationVersion is the GORM model of one recorded migration step. GormVersioner auto-migrates and persists it.
// It satisfies the AppliedMigration accessors through its value-receiver getters.
type MigrationVersion struct {
	Version       Version `gorm:"primaryKey"`
	Description   string
	ExecutionTime time.Duration
	InstalledOn   time.Time
	Success       bool
}

// GetVersion returns the recorded step's version (the primary key).
func (mv MigrationVersion) GetVersion() Version { return mv.Version }

// GetDescription returns the recorded step's description.
func (mv MigrationVersion) GetDescription() string { return mv.Description }

// IsSuccess reports whether the recorded step completed without error.
func (mv MigrationVersion) IsSuccess() bool { return mv.Success }

// GetInstalledOn returns when the step finished executing.
func (mv MigrationVersion) GetInstalledOn() time.Time { return mv.InstalledOn }
// GormVersioner is a Versioner that keeps the migration history in the table modeled by MigrationVersion, using GORM.
type GormVersioner struct {
	db *gorm.DB
}

// NewGormVersioner returns a GORM-backed Versioner operating on db.
func NewGormVersioner(db *gorm.DB) Versioner {
	gv := &GormVersioner{}
	gv.db = db
	return gv
}

// CreateVersionTableIfNotExist creates (or updates) the version table by auto-migrating MigrationVersion.
func (v *GormVersioner) CreateVersionTableIfNotExist(ctx context.Context) error {
	tx := v.db.WithContext(ctx)
	return tx.AutoMigrate(&MigrationVersion{})
}

// GetAppliedMigrations returns every recorded migration in the order the query yields them; no ordering is applied.
// On success the slice is non-nil, possibly empty.
func (v *GormVersioner) GetAppliedMigrations(ctx context.Context) ([]AppliedMigration, error) {
	rows := []MigrationVersion{}
	if err := v.db.WithContext(ctx).Find(&rows).Error; err != nil {
		return nil, err
	}
	applied := make([]AppliedMigration, 0, len(rows))
	for i := range rows {
		applied = append(applied, rows[i])
	}
	return applied, nil
}

// RecordAppliedMigration persists the outcome of one migration step with gorm's Save, keyed by version.
func (v *GormVersioner) RecordAppliedMigration(ctx context.Context, version Version, description string, success bool, installedOn time.Time, executionTime time.Duration) error {
	record := &MigrationVersion{
		Version:       version,
		Description:   description,
		ExecutionTime: executionTime,
		InstalledOn:   installedOn,
		Success:       success,
	}
	return v.db.WithContext(ctx).Save(record).Error
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package migration
import (
"context"
"github.com/cisco-open/go-lanai/pkg/utils"
"gorm.io/gorm"
"io/fs"
)
// MigrationFunc performs a single migration step and returns its error, if any.
type MigrationFunc func(ctx context.Context) error

// Registrar collects the migration steps that Migrate executes.
type Registrar struct {
	migrationSteps []*Migration
	versions       utils.StringSet
}

// NewRegistrar returns an empty Registrar.
func NewRegistrar() *Registrar {
	return new(Registrar)
}

// AddMigrations registers the given steps. Migrate sorts registered steps by version before running them.
func (r *Registrar) AddMigrations(m ...*Migration) {
	for _, step := range m {
		r.migrationSteps = append(r.migrationSteps, step)
	}
}
// Migration describes one migration step. It is built fluently, e.g.
// WithVersion("1.0").Dot(1).WithDesc("...").WithTag(TagPreUpgrade).WithFunc(fn).
type Migration struct {
	Version     Version
	Description string
	Func        MigrationFunc
	Tags        utils.StringSet
}

// WithVersion starts a Migration at the given dotted version. It panics if version cannot be parsed.
func WithVersion(version string) *Migration {
	parsed, err := fromString(version)
	if err != nil {
		panic(err)
	}
	m := new(Migration)
	m.Version = parsed
	return m
}

// Dot appends component i to the version (e.g. 1.2 becomes 1.2.i) and returns m for chaining.
func (m *Migration) Dot(i int) *Migration {
	m.Version = append(m.Version, i)
	return m
}

// WithTag adds tags to the step and returns m for chaining.
func (m *Migration) WithTag(tags ...string) *Migration {
	if m.Tags != nil {
		m.Tags.Add(tags...)
		return m
	}
	m.Tags = utils.NewStringSet(tags...)
	return m
}

// WithFile makes the step execute the SQL statements read from filePath in fsys against db.
func (m *Migration) WithFile(fsys fs.FS, filePath string, db *gorm.DB) *Migration {
	m.Func = migrationFuncFromTextFile(fsys, filePath, db)
	return m
}

// WithFunc makes the step execute f.
func (m *Migration) WithFunc(f MigrationFunc) *Migration {
	m.Func = f
	return m
}

// WithDesc sets the human-readable description recorded with the step.
func (m *Migration) WithDesc(d string) *Migration {
	m.Description = d
	return m
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package migration
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"sort"
"time"
)
// Migrate executes the registered migration steps that still need to run and records the outcome of each executed
// step through v.
//
// In order, it:
//   - ensures the version table exists (v.CreateVersionTableIfNotExist);
//   - sorts r's registered steps by ascending Version (stable, in place);
//   - loads the applied migrations and refuses to continue if any of them is recorded as failed;
//   - skips steps not tagged with filterFlag when that flag is set;
//   - with allowOutOfOrderFlag, runs every step whose version has not been recorded yet; otherwise runs only
//     steps newer than the highest recorded version.
//
// Execution stops at the first failing step. The failure is recorded on a best-effort basis, and the returned error
// wraps the step's error so callers can inspect it with errors.Is / errors.As. A step registered without a function
// is reported as an error instead of crashing the runner.
func Migrate(ctx context.Context, r *Registrar, v Versioner) error {
	if err := v.CreateVersionTableIfNotExist(ctx); err != nil {
		return err
	}

	// sort registered migration steps
	sort.SliceStable(r.migrationSteps, func(i, j int) bool {
		return r.migrationSteps[i].Version.Lt(r.migrationSteps[j].Version)
	})

	appliedMigrations, err := v.GetAppliedMigrations(ctx)
	if err != nil {
		return err
	}
	// sort applied migration steps; the last element is then the highest applied version
	sort.SliceStable(appliedMigrations, func(i, j int) bool {
		return appliedMigrations[i].GetVersion().Lt(appliedMigrations[j].GetVersion())
	})
	for _, a := range appliedMigrations {
		if !a.IsSuccess() {
			return errors.New("stopping migration because there is a failed migration step: " + a.GetVersion().String())
		}
	}

	var shouldExecuteMigration func(*Migration) bool
	if allowOutOfOrderFlag {
		appliedVersions := utils.NewStringSet()
		for _, a := range appliedMigrations {
			appliedVersions.Add(a.GetVersion().String())
		}
		shouldExecuteMigration = func(m *Migration) bool {
			return !appliedVersions.Has(m.Version.String())
		}
	} else {
		var lastAppliedMigration AppliedMigration
		if len(appliedMigrations) > 0 {
			lastAppliedMigration = appliedMigrations[len(appliedMigrations)-1]
		}
		shouldExecuteMigration = func(m *Migration) bool {
			return lastAppliedMigration == nil || lastAppliedMigration.GetVersion().Lt(m.Version)
		}
	}

	for _, s := range r.migrationSteps {
		if filterFlag != "" && !s.Tags.Has(filterFlag) {
			continue
		}
		if !shouldExecuteMigration(s) {
			continue
		}
		// A step built without WithFunc/WithFile would otherwise panic on a nil function call.
		if s.Func == nil {
			return fmt.Errorf("migration step %v has no migration function", s.Version)
		}
		//TODO: should the migration func and recording the version be put in one transaction?
		logger.Infof("Executing migration step %s: %s", s.Version.String(), s.Description)
		startTime := time.Now()
		migrationErr := s.Func(ctx) //TODO: manual rollback function?
		finishTime := time.Now()
		duration := finishTime.Sub(startTime)
		if migrationErr != nil {
			// Best effort: the step's own error matters more than a failure to record it.
			if recordErr := v.RecordAppliedMigration(ctx, s.Version, s.Description, false, finishTime, duration); recordErr != nil {
				logger.Errorf("error recording failed migration version due to %v", recordErr)
			}
			err = fmt.Errorf("migration stopped at step %v because of error: %w", s.Version, migrationErr)
			logger.Errorf("%v", err)
			return err
		}
		if err := v.RecordAppliedMigration(ctx, s.Version, s.Description, true, finishTime, duration); err != nil {
			return err
		}
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package migration
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"go.uber.org/fx"
"gorm.io/gorm"
)
// Tag values suggested by the "filter" command-line flag's help text (see Use).
const (
	TagPreUpgrade  = "pre_upgrade"
	TagPostUpgrade = "post_upgrade"
)

var (
	logger = log.New("Migration")
	// filterFlag holds the "filter" flag; when non-empty, only steps carrying this tag are executed.
	filterFlag string
	// allowOutOfOrderFlag holds the "allow_out_of_order" flag.
	allowOutOfOrderFlag bool
)

// Module provides the Registrar, the GORM-backed Versioner and the migration CLI runner.
var Module = &bootstrap.Module{
	Name:       "migration",
	Precedence: bootstrap.CommandLineRunnerPrecedence,
	Options: []fx.Option{
		fx.Provide(NewRegistrar),
		fx.Provide(NewGormVersioner),
		fx.Provide(provideMigrationRunner()),
	},
}
// Use declares the migration command-line flags, registers Module with bootstrap, and switches the application into
// CLI-runner mode so the migration runner provided by Module executes.
//
// Flags:
//   - filter: only run steps tagged with the given value (TagPreUpgrade or TagPostUpgrade)
//   - allow_out_of_order: run any step whose version has not been recorded, not only steps newer than the latest
func Use() {
	bootstrap.AddStringFlag(&filterFlag, "filter", "", fmt.Sprintf("filter the migration steps by tag value. supports %s or %s", TagPreUpgrade, TagPostUpgrade))
	// plain string: the help text has no formatting verbs, so fmt.Sprintf was redundant
	bootstrap.AddBoolFlag(&allowOutOfOrderFlag, "allow_out_of_order", false, "allow migration steps to execute out of order")
	bootstrap.Register(Module)
	// Note: migration CliRunner is provided in Module
	bootstrap.EnableCliRunnerMode()
}
// provideMigrationRunner annotates newMigrationRunner so that its CliRunner joins bootstrap.FxCliRunnerGroup.
func provideMigrationRunner() fx.Annotated {
	var annotated fx.Annotated
	annotated.Group = bootstrap.FxCliRunnerGroup
	annotated.Target = newMigrationRunner
	return annotated
}
// migrationRunnerIn lists the fx dependencies of newMigrationRunner.
type migrationRunnerIn struct {
fx.In
R *Registrar
V Versioner
DB *gorm.DB
// DbCreators are optional; newMigrationRunner sorts them with order.OrderedFirstCompare and uses the first one to
// create the database before migrating.
// NOTE(review): the group tag is "gorm_config". Confirm this is the group DbCreator implementations are provided
// under; it reads like a GORM configurer group name.
DbCreators []data.DbCreator `group:"gorm_config"`
}
// newMigrationRunner returns the CLI runner that optionally creates the database (using the highest-priority
// DbCreator) and then runs Migrate with the injected Registrar and Versioner.
func newMigrationRunner(di migrationRunnerIn) bootstrap.CliRunner {
	return func(ctx context.Context) error {
		if len(di.DbCreators) != 0 {
			order.SortStable(di.DbCreators, order.OrderedFirstCompare)
			if err := di.DbCreators[0].CreateDatabaseIfNotExist(ctx, di.DB); err != nil {
				return err
			}
		}
		return Migrate(ctx, di.R, di.V)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package migration
import (
"context"
"errors"
"fmt"
"gorm.io/gorm"
"io"
"io/fs"
"strings"
)
// migrationFuncFromTextFile loads a SQL script from fsys and returns a MigrationFunc that executes it against db.
//
// The file is read eagerly, when the migration is registered, so a missing or unreadable file panics at that point
// rather than when the migration runs. The script is split on ";" once, and each non-blank statement is executed in
// order. Execution stops at the first failing statement, whose error is returned.
//
// NOTE: splitting on ";" is naive. Statements containing ";" inside string literals or procedural bodies
// (e.g. PL/pgSQL functions) will be split incorrectly.
func migrationFuncFromTextFile(fsys fs.FS, filePath string, db *gorm.DB) MigrationFunc {
	file, err := fsys.Open(filePath)
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			panic(fmt.Errorf("%s does not exist or is not a file: %w", filePath, err))
		}
		panic(fmt.Errorf("unable to open migration file %s: %w", filePath, err))
	}
	// The handle is read-only, so a Close error carries nothing actionable, but the descriptor must not leak.
	defer func() { _ = file.Close() }()

	content, err := io.ReadAll(file)
	if err != nil {
		panic(fmt.Errorf("unable to read migration file %s: %w", filePath, err))
	}

	// The script never changes after loading, so split it once instead of on every run.
	var queries []string
	for _, query := range strings.Split(string(content), ";") {
		if query = strings.TrimSpace(query); query != "" {
			queries = append(queries, query)
		}
	}

	return func(ctx context.Context) error {
		for _, query := range queries {
			logger.Debugf("executing query %s", query)
			if result := db.WithContext(ctx).Exec(query); result.Error != nil {
				return result.Error
			}
		}
		return nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package migration
import (
"database/sql/driver"
"fmt"
"github.com/pkg/errors"
"strconv"
"strings"
)
// Version is a dotted numeric migration version; "1.2.3" is Version{1, 2, 3}.
type Version []int

// Lt reports whether v orders strictly before other.
// Components are compared pairwise from the left and the first differing component decides. If one version is a
// prefix of the other, the shorter one orders first (so 1.2 < 1.2.0).
func (v Version) Lt(other Version) bool {
	for i := 0; i < len(v) && i < len(other); i++ {
		if v[i] != other[i] {
			return v[i] < other[i]
		}
	}
	return len(v) < len(other)
}

// String renders v with its components joined by "."; an empty or nil Version renders as "".
func (v Version) String() string {
	parts := make([]string, len(v))
	for i, component := range v {
		parts[i] = strconv.Itoa(component)
	}
	return strings.Join(parts, ".")
}

// Equals reports whether v and o have the same length and identical components.
func (v Version) Equals(o Version) bool {
	if len(v) != len(o) {
		return false
	}
	for i := range v {
		if v[i] != o[i] {
			return false
		}
	}
	return true
}
// Scan implements sql.Scanner. It accepts the dotted string form, as string or []byte, and NULL (nil), which yields a
// nil Version. Any other source type is rejected. On a parse error the receiver is left unchanged.
func (v *Version) Scan(src interface{}) error {
	switch src := src.(type) {
	case []byte:
		return v.scanString(string(src))
	case string:
		return v.scanString(src)
	case nil:
		*v = nil
		return nil
	default:
		// The error names this package's type; it previously reported "pq: ... StringArray".
		return fmt.Errorf("cannot convert %T to Version", src)
	}
}

// Value implements driver.Valuer. It stores v in its dotted string form, and a nil Version as NULL.
func (v Version) Value() (driver.Value, error) {
	if v == nil {
		return nil, nil
	}
	return v.String(), nil
}

// GormDataType reports the GORM data type used for Version columns.
func (v Version) GormDataType() string {
	return "string"
}

// scanString parses src and assigns the result to v only when parsing succeeds.
func (v *Version) scanString(src string) error {
	result, err := fromString(src)
	if err != nil {
		return err
	}
	*v = result
	return nil
}
// fromString parses a dotted numeric version such as "1.2.3".
// Every "."-separated component must parse with strconv.Atoi, and an empty source is rejected.
func fromString(source string) (Version, error) {
	// strings.Split never returns an empty slice ("" yields [""]), so emptiness must be checked on the input itself.
	if source == "" {
		return Version{}, errors.New("Version must have at least one numeric component")
	}
	parts := strings.Split(source, ".")
	numbers := make(Version, 0, len(parts))
	for _, part := range parts {
		number, err := strconv.Atoi(part)
		if err != nil {
			return Version{}, errors.Wrap(err, "Cannot parse component as integer")
		}
		numbers = append(numbers, number)
	}
	return numbers, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opaaccess
import (
"context"
"github.com/cisco-open/go-lanai/pkg/opa"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"net/http"
)
// DecisionMakerWithOPA returns an access.DecisionMakerFunc backed by the OPA engine (opa.AllowRequest).
// It always reports the request as handled. A denial or evaluation error from OPA is returned as an access-denied
// decision, and a nil decision means access is granted.
func DecisionMakerWithOPA(opts ...opa.RequestQueryOptions) access.DecisionMakerFunc {
	return func(ctx context.Context, req *http.Request) (bool, error) {
		if err := opa.AllowRequest(ctx, req, opts...); err != nil {
			return true, security.NewAccessDeniedError(err)
		}
		return true, nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"context"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/open-policy-agent/opa/sdk"
)
// contextWithOverriddenLogLevel returns ctx carrying the given decision log level, or ctx unchanged when override is nil.
func contextWithOverriddenLogLevel(ctx context.Context, override *log.LoggingLevel) context.Context {
	if override != nil {
		return logContextWithLevel(ctx, *override)
	}
	return ctx
}
// handleDecisionResult converts an OPA decision, or the error from evaluating it, into an access decision. A nil
// return means allow; any error means deny.
//
//   - undefined result (sdk.IsUndefinedErr): denied via errorWithTargetName(targetName)
//   - other evaluation errors: ErrAccessDenied carrying the error
//   - no result and no error: denied, rather than dereferencing a nil result
//   - boolean result: allow when true, errorWithTargetName(targetName) when false
//   - any other result type: ErrAccessDenied ("unsupported OPA result type")
//
// In every case a decision event is logged at debug level via eventLogger when the function returns.
func handleDecisionResult(ctx context.Context, result *sdk.DecisionResult, rErr error, targetName string) (err error) {
	var parsedResult interface{}
	defer func() {
		event := &resultEvent{
			Result: parsedResult,
			Deny:   err != nil,
		}
		if result != nil {
			event.ID = result.ID
		}
		if err == nil {
			eventLogger(ctx, log.LevelDebug).WithKV(kLogDecisionReason, event).Printf("Allow [%v]", event.ID)
		} else {
			eventLogger(ctx, log.LevelDebug).WithKV(kLogDecisionReason, event).Printf("Deny [%v]", event.ID)
		}
	}()
	switch {
	case sdk.IsUndefinedErr(rErr):
		parsedResult = "not true"
		return errorWithTargetName(targetName)
	case rErr != nil:
		parsedResult = rErr
		return ErrAccessDenied.WithMessage("unable to execute OPA query: %v", rErr)
	case result == nil:
		// Defensive: deny instead of panicking if the SDK ever reports neither a result nor an error.
		return ErrAccessDenied.WithMessage("OPA query returned no decision result")
	}
	parsedResult = result.Result
	switch v := result.Result.(type) {
	case bool:
		if !v {
			return errorWithTargetName(targetName)
		}
	default:
		return ErrAccessDenied.WithMessage("unsupported OPA result type %T", result.Result)
	}
	return nil
}
// errorWithTargetName returns ErrAccessDenied, with the message "<targetName> Access Denied" when a target name is given.
func errorWithTargetName(targetName string) error {
	if targetName != "" {
		return ErrAccessDenied.WithMessage("%s Access Denied", targetName)
	}
	return ErrAccessDenied
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"context"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/open-policy-agent/opa/sdk"
"time"
)
// QueryOptions customizes a Query before it is evaluated (see Allow).
type QueryOptions func(q *Query)
// Query holds the settings of a generic policy query issued by Allow.
type Query struct {
// OPA is the instance used for evaluation; Allow defaults it to EmbeddedOPA()
OPA *sdk.OPA
// Policy is the decision path (sdk.DecisionOptions.Path); Allow requires it
Policy string
// InputCustomizers are applied in order to the input built by constructGenericDecisionInput;
// Allow seeds them with the embedded OPA's customizers
InputCustomizers []InputCustomizer
// RawInput, when non-nil, is used as the query input as-is, and customizers are not applied
RawInput interface{}
// LogLevel override decision log level when presented
LogLevel *log.LoggingLevel
}
// QueryWithPolicy sets the decision path to evaluate (required by Allow).
func QueryWithPolicy(policy string) QueryOptions {
	return func(query *Query) {
		query.Policy = policy
	}
}
// SilentQuery turns decision logging off (log.LevelOff) for the query.
func SilentQuery() QueryOptions {
	level := log.LevelOff
	return func(query *Query) {
		query.LogLevel = &level
	}
}
// QueryWithInputCustomizer appends customizer to the query's input customizers; they run in order.
func QueryWithInputCustomizer(customizer InputCustomizerFunc) QueryOptions {
	return func(q *Query) {
		// Copy before appending. Allow seeds q.InputCustomizers with the package-level embeddedOPA.inputCustomizers,
		// and an in-place append could write into that shared backing array, which concurrent queries also see.
		customizers := make([]InputCustomizer, 0, len(q.InputCustomizers)+1)
		customizers = append(customizers, q.InputCustomizers...)
		q.InputCustomizers = append(customizers, customizer)
	}
}
// Allow is the generic API for querying a policy. It only populates minimum input data, such as authentication.
// For more specialized functions, see AllowResource, AllowRequest, etc.
//
// The policy must be set with QueryWithPolicy, otherwise ErrInternal is returned. A nil error means the decision was
// "allow"; a denial or evaluation failure is returned as an error.
func Allow(ctx context.Context, opts ...QueryOptions) error {
	opaInstance := EmbeddedOPA()
	defaults := embeddedOPA.inputCustomizers
	query := Query{
		OPA: opaInstance,
		// Clip the capacity so options that append customizers allocate a new backing array instead of writing
		// into the package-level slice shared by all concurrent queries.
		InputCustomizers: defaults[:len(defaults):len(defaults)],
	}
	for _, fn := range opts {
		fn(&query)
	}
	if len(query.Policy) == 0 {
		return ErrInternal.WithMessage("policy is required for generic Allow function")
	}
	ctx = contextWithOverriddenLogLevel(ctx, query.LogLevel)
	opaOpts, e := PrepareGenericDecisionQuery(ctx, &query)
	if e != nil {
		return ErrInternal.WithMessage(`error when preparing OPA input: %v`, e)
	}
	result, e := query.OPA.Decision(ctx, *opaOpts)
	return handleDecisionResult(ctx, result, e, "")
}
// PrepareGenericDecisionQuery builds the sdk.DecisionOptions for query: the current time, query.Policy as the decision
// path, and the input from constructGenericDecisionInput. Any input construction error is returned unchanged.
func PrepareGenericDecisionQuery(ctx context.Context, query *Query) (*sdk.DecisionOptions, error) {
	input, err := constructGenericDecisionInput(ctx, query)
	if err != nil {
		return nil, err
	}
	// Debugging tip: json.Marshal(opts.Input) and log it at debug level to inspect the exact input sent to OPA.
	return &sdk.DecisionOptions{
		Now:                 time.Now(),
		Path:                query.Policy,
		Input:               input,
		StrictBuiltinErrors: false,
	}, nil
}
// constructGenericDecisionInput returns query.RawInput when set. Otherwise it builds a new input with the
// authentication clause and applies query.InputCustomizers in order, stopping at the first error.
func constructGenericDecisionInput(ctx context.Context, query *Query) (interface{}, error) {
	if raw := query.RawInput; raw != nil {
		return raw, nil
	}
	input := NewInput()
	input.Authentication = NewAuthenticationClause()
	for i := range query.InputCustomizers {
		if err := query.InputCustomizers[i].Customize(ctx, input); err != nil {
			return nil, err
		}
	}
	return input, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/open-policy-agent/opa/sdk"
"time"
)
// ContextAwarePartialQueryMapper is a sdk.PartialQueryMapper that can be bound to a context.
// PrepareResourcePartialQuery calls WithContext(ctx) and uses the returned mapper for the partial evaluation, so
// implementations can access the request context while mapping partial results.
// Context presumably returns the context bound via WithContext; confirm in implementations.
type ContextAwarePartialQueryMapper interface {
sdk.PartialQueryMapper
WithContext(ctx context.Context) sdk.PartialQueryMapper
Context() context.Context
}
// ResourceFilterOptions customizes a ResourceFilter before partial evaluation (see FilterResource).
type ResourceFilterOptions func(rf *ResourceFilter)
// ResourceFilter holds the settings of a partial-evaluation query issued by FilterResource.
type ResourceFilter struct {
// OPA (Optional) instance to use for evaluation. Default to EmbeddedOPA()
OPA *sdk.OPA
// Query (Optional) OPA query to evaluate.
// Default to `data.<PackagePrefixResource>.<resource_type>.filter_<resource_operation>`,
// e.g. `data.resource.<resource_type>.filter_<resource_operation>`
Query string
// Unknowns (Required) List of unknown input fields for partial evaluation. Not providing "unknowns" would not
// result in an immediate error, but would very likely result in access denial.
Unknowns []string
// QueryMapper (Optional) Custom sdk.PartialQueryMapper for translating result rego.PartialQueries.
// By default, partial result is *rego.PartialQueries. QueryMapper can translate it to other structure.
// e.g. SQL "Where" clause
QueryMapper sdk.PartialQueryMapper
// Delta (Optional) Resource's "changed-to" fields and values. Delta is only applicable to "write" operation.
// OPA policies may have rules on what values the resource's certain fields can be changed to.
Delta *ResourceValues
// ExtraData (Optional) any key-value pairs in ExtraData will be added into query input under `input.resource.*`
ExtraData map[string]interface{}
// InputCustomizers customizers to finalize/modify query input before evaluation
InputCustomizers []InputCustomizer
// RawInput overrides any input related options
RawInput interface{}
// LogLevel override decision log level when presented
LogLevel *log.LoggingLevel
}
// SilentResourceFilter turns decision logging off (log.LevelOff) for the partial evaluation.
func SilentResourceFilter() ResourceFilterOptions {
	level := log.LevelOff
	return func(rf *ResourceFilter) {
		rf.LogLevel = &level
	}
}
// FilterResource partially evaluates the resource filter policy for resType/op and returns the (mapped) partial
// result.
//
// Defaults: EmbeddedOPA(), the embedded OPA's input customizers, a raw mapper, and the query
// "data.<PackagePrefixResource>.<resType>.filter_<op>" unless overridden by opts. Input preparation failures are
// reported as ErrInternal; evaluation failures are reported as access-denied errors.
func FilterResource(ctx context.Context, resType string, op ResourceOperation, opts ...ResourceFilterOptions) (*sdk.PartialResult, error) {
	opaInstance := EmbeddedOPA()
	defaults := embeddedOPA.inputCustomizers
	res := ResourceFilter{
		OPA: opaInstance,
		// Clip the capacity so options that append customizers allocate a new backing array instead of writing
		// into the package-level slice shared by all concurrent queries.
		InputCustomizers: defaults[:len(defaults):len(defaults)],
		QueryMapper:      &sdk.RawMapper{},
		ExtraData:        map[string]interface{}{},
	}
	for _, fn := range opts {
		fn(&res)
	}
	if len(res.Query) == 0 {
		res.Query = fmt.Sprintf("data.%s.%s.filter_%v", PackagePrefixResource, resType, op)
	}
	ctx = contextWithOverriddenLogLevel(ctx, res.LogLevel)
	opaOpts, e := PrepareResourcePartialQuery(ctx, res.Query, resType, op, &res)
	if e != nil {
		return nil, ErrInternal.WithMessage(`error when preparing OPA input: %v`, e)
	}
	result, e := res.OPA.Partial(ctx, *opaOpts)
	return handlePartialResult(ctx, result, e)
}
// PrepareResourcePartialQuery builds the sdk.PartialOptions for a resource filter: the current time, the input from
// constructResourcePartialInput, policy as the query, res.Unknowns, and res.QueryMapper. A
// ContextAwarePartialQueryMapper is first bound to ctx.
func PrepareResourcePartialQuery(ctx context.Context, policy string, resType string, op ResourceOperation, res *ResourceFilter) (*sdk.PartialOptions, error) {
	input, err := constructResourcePartialInput(ctx, resType, op, res)
	if err != nil {
		return nil, err
	}
	var mapper sdk.PartialQueryMapper
	if aware, ok := res.QueryMapper.(ContextAwarePartialQueryMapper); ok {
		mapper = aware.WithContext(ctx)
	} else {
		mapper = res.QueryMapper
	}
	// Debugging tip: json.Marshal(opts.Input) and log it at debug level to inspect the exact input sent to OPA.
	return &sdk.PartialOptions{
		Now:      time.Now(),
		Input:    input,
		Query:    policy,
		Unknowns: res.Unknowns,
		Mapper:   mapper,
	}, nil
}
// constructResourcePartialInput returns res.RawInput when set. Otherwise it builds a new input with the authentication
// clause and a resource clause for resType/op (carrying res.ExtraData and res.Delta), then applies
// res.InputCustomizers in order, stopping at the first error.
func constructResourcePartialInput(ctx context.Context, resType string, op ResourceOperation, res *ResourceFilter) (interface{}, error) {
	if raw := res.RawInput; raw != nil {
		return raw, nil
	}
	input := NewInput()
	input.Authentication = NewAuthenticationClause()
	input.Resource = NewResourceClause(resType, op)
	input.Resource.ExtraData, input.Resource.Delta = res.ExtraData, res.Delta
	for i := range res.InputCustomizers {
		if err := res.InputCustomizers[i].Customize(ctx, input); err != nil {
			return nil, err
		}
	}
	return input, nil
}
// handlePartialResult converts the outcome of a partial evaluation into (result, nil) on success or (nil, denial).
//
//   - undefined result (sdk.IsUndefinedErr): ErrAccessDenied
//   - ErrQueriesNotResolved: ErrAccessDenied carrying its message
//   - other evaluation errors: ErrAccessDenied ("failed to perform partial evaluation")
//   - no result and no error: ErrAccessDenied, rather than dereferencing a nil result
//
// In every case a partial-result event is logged at debug level via eventLogger when the function returns.
func handlePartialResult(ctx context.Context, result *sdk.PartialResult, rErr error) (_ *sdk.PartialResult, err error) {
	var event partialResultEvent
	defer func() {
		if err == nil {
			eventLogger(ctx, log.LevelDebug).WithKV(kLogPartialResult, event).Printf("Partial [%s]", event.ID)
		} else {
			eventLogger(ctx, log.LevelDebug).WithKV(kLogPartialReason, event).Printf("Deny Partial [%s]", event.ID)
		}
	}()
	if result != nil {
		event.ID = result.ID
	}
	if rErr != nil {
		event.Err = rErr
		switch {
		case sdk.IsUndefinedErr(rErr):
			return nil, ErrAccessDenied
		case errors.Is(rErr, ErrQueriesNotResolved):
			return nil, ErrAccessDenied.WithMessage("%s", rErr.Error())
		default:
			return nil, ErrAccessDenied.WithMessage("failed to perform partial evaluation: %v", rErr)
		}
	}
	if result == nil {
		// Defensive: deny instead of panicking if the SDK ever reports neither a result nor an error.
		return nil, ErrAccessDenied.WithMessage("partial evaluation returned no result")
	}
	event.AST = (*partialQueriesLog)(result.AST)
	return result, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"context"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/open-policy-agent/opa/sdk"
"net/http"
"time"
)
// RequestQueryOptions customizes a RequestQuery before evaluation (see AllowRequest).
type RequestQueryOptions func(opt *RequestQuery)
// RequestQuery holds the settings of an HTTP-request policy query issued by AllowRequest.
type RequestQuery struct {
// OPA is the instance used for evaluation; AllowRequest defaults it to EmbeddedOPA()
OPA *sdk.OPA
// Policy is the decision path (sdk.DecisionOptions.Path)
Policy string
// ExtraData is attached to the request clause of the input (input.Request.ExtraData)
ExtraData map[string]interface{}
// InputCustomizers are applied in order to the constructed input; AllowRequest seeds them with the
// embedded OPA's customizers
InputCustomizers []InputCustomizer
// RawInput overrides any input related options
RawInput interface{}
// LogLevel override decision log level when presented
LogLevel *log.LoggingLevel
}
func RequestQueryWithPolicy(policy string) RequestQueryOptions {
return func(opt *RequestQuery) {
opt.Policy = policy
}
}
func SilentRequestQuery() RequestQueryOptions {
var silent = log.LevelOff
return func(opt *RequestQuery) {
opt.LogLevel = &silent
}
}
func AllowRequest(ctx context.Context, req *http.Request, opts ...RequestQueryOptions) error {
opt := RequestQuery{
OPA: EmbeddedOPA(),
InputCustomizers: embeddedOPA.inputCustomizers,
ExtraData: map[string]interface{}{},
}
for _, fn := range opts {
fn(&opt)
}
ctx = contextWithOverriddenLogLevel(ctx, opt.LogLevel)
opaOpts, e := PrepareRequestDecisionQuery(ctx, opt.Policy, req, &opt)
if e != nil {
return ErrInternal.WithMessage(`error when preparing OPA input: %v`, e)
}
result, e := opt.OPA.Decision(ctx, *opaOpts)
return handleDecisionResult(ctx, result, e, "API")
}
func PrepareRequestDecisionQuery(ctx context.Context, policy string, req *http.Request, opt *RequestQuery) (*sdk.DecisionOptions, error) {
input, e := constructRequestDecisionInput(ctx, req, opt)
if e != nil {
return nil, e
}
opts := sdk.DecisionOptions{
Now: time.Now(),
Path: policy,
Input: input,
StrictBuiltinErrors: false,
}
//if data, e := json.Marshal(opts.Input); e != nil {
// eventLogger(ctx, log.LevelError).Printf("Input marshalling error: %v", e)
//} else {
// eventLogger(ctx, log.LevelDebug).Printf("Input: %s", data)
//}
return &opts, nil
}
func constructRequestDecisionInput(ctx context.Context, req *http.Request, opt *RequestQuery) (interface{}, error) {
if opt.RawInput != nil {
return opt.RawInput, nil
}
input := NewInput()
input.Authentication = NewAuthenticationClause()
input.Request = NewRequestClause(req)
input.Request.ExtraData = opt.ExtraData
for _, customizer := range opt.InputCustomizers {
if e := customizer.Customize(ctx, input); e != nil {
return nil, e
}
}
return input, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/open-policy-agent/opa/sdk"
"time"
)
type ResourceQueryOptions func(res *ResourceQuery)
type ResourceQuery struct {
// OPA (Optional) Instance to use for evaluation. Default to EmbeddedOPA()
OPA *sdk.OPA
// Policy (Optional) OPA query/policy to evaluate.
// Default to `resource/<resource_type>/allow_<resource_operation>`
Policy string
// ResourceValues (Required) Resource's current fields and values that policy may be interested in
ResourceValues
// Delta (Optional) Resource's "changed-to" fields and values. Delta is only applicable to "write" operation.
// OPA policies may have rules on what values the resource's certain fields can be changed to.
Delta *ResourceValues
// InputCustomizers customizers to finalize/modify query input before evaluation
InputCustomizers []InputCustomizer
// RawInput overrides any input related options
RawInput interface{}
// LogLevel override decision log level when presented
LogLevel *log.LoggingLevel
}
func SilentResourceQuery() ResourceQueryOptions {
var silent = log.LevelOff
return func(opt *ResourceQuery) {
opt.LogLevel = &silent
}
}
func AllowResource(ctx context.Context, resType string, op ResourceOperation, opts ...ResourceQueryOptions) error {
res := ResourceQuery{
OPA: EmbeddedOPA(),
InputCustomizers: embeddedOPA.inputCustomizers,
ResourceValues: ResourceValues{ExtraData: map[string]interface{}{}},
}
for _, fn := range opts {
fn(&res)
}
if len(res.Policy) == 0 {
res.Policy = fmt.Sprintf("%s/%s/allow_%v", PackagePrefixResource, resType, op)
}
ctx = contextWithOverriddenLogLevel(ctx, res.LogLevel)
opaOpts, e := PrepareResourceDecisionQuery(ctx, res.Policy, resType, op, &res)
if e != nil {
return ErrInternal.WithMessage(`error when preparing OPA input: %v`, e)
}
result, e := res.OPA.Decision(ctx, *opaOpts)
return handleDecisionResult(ctx, result, e, "ResourceQuery")
}
func PrepareResourceDecisionQuery(ctx context.Context, policy string, resType string, op ResourceOperation, res *ResourceQuery) (*sdk.DecisionOptions, error) {
input, e := constructResourceDecisionInput(ctx, resType, op, res)
if e != nil {
return nil, e
}
opts := sdk.DecisionOptions{
Now: time.Now(),
Path: policy,
Input: input,
StrictBuiltinErrors: false,
}
//if data, e := json.Marshal(opts.Input); e != nil {
// eventLogger(ctx, log.LevelError).Printf("Input marshalling error: %v", e)
//} else {
// eventLogger(ctx, log.LevelDebug).Printf("Input: %s", data)
//}
return &opts, nil
}
func constructResourceDecisionInput(ctx context.Context, resType string, op ResourceOperation, res *ResourceQuery) (interface{}, error) {
if res.RawInput != nil {
return res.RawInput, nil
}
input := NewInput()
input.Authentication = NewAuthenticationClause()
input.Resource = NewResourceClause(resType, op)
input.Resource.CurrentResourceValues = CurrentResourceValues(res.ResourceValues)
input.Resource.Delta = res.Delta
for _, customizer := range res.InputCustomizers {
if e := customizer.Customize(ctx, input); e != nil {
return nil, e
}
}
return input, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opaactuator
import (
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/opa"
opaaccess "github.com/cisco-open/go-lanai/pkg/opa/access"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"regexp"
)
const RequestInputKeyEndpointID = `endpoint_id`
func NewAccessControlWithOPA(props actuator.SecurityProperties, opts ...opa.RequestQueryOptions) actuator.AccessControlCustomizer {
return actuator.AccessControlCustomizeFunc(func(ac *access.AccessControlFeature, epId string, paths []string) {
if len(paths) == 0 {
return
}
// configure request matchers
reqMatcher := pathToRequestPattern(paths[0])
for _, p := range paths[1:] {
reqMatcher = reqMatcher.Or(pathToRequestPattern(p))
}
switch {
case !isSecurityEnabled(epId, &props):
ac.Request(reqMatcher).PermitAll()
default:
opts = append(opts, func(opt *opa.RequestQuery) {
opt.ExtraData[RequestInputKeyEndpointID] = epId
})
ac.Request(reqMatcher).CustomDecisionMaker(opaaccess.DecisionMakerWithOPA(opts...))
}
})
}
var pathVarRegex = regexp.MustCompile(`:[a-zA-Z0-9\-_]+`)
// pathToRequestPattern convert path variables to wildcard request pattern
// "/path/to/:any/endpoint" would converted to "/path/to/*/endpoint
func pathToRequestPattern(path string) web.RequestMatcher {
patternStr := pathVarRegex.ReplaceAllString(path, "*")
return matcher.RequestWithPattern(patternStr)
}
func isSecurityEnabled(epId string, properties *actuator.SecurityProperties) bool {
enabled := properties.EnabledByDefault
if props, ok := properties.Endpoints[epId]; ok {
if props.Enabled != nil {
enabled = *props.Enabled
}
}
return enabled
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opaactuator
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
"github.com/cisco-open/go-lanai/pkg/opa"
)
func NewHealthDisclosureControlWithOPA(opts ...opa.QueryOptions) health.DisclosureControl {
return health.DisclosureControlFunc(func(ctx context.Context) bool {
e := opa.Allow(ctx, opts...)
return e == nil
})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
)
const SpecialScopeAdmin = `admin`
type MockedHealthIndicator struct {
Status health.Status
Description string
Details map[string]interface{}
}
func NewMockedHealthIndicator() *MockedHealthIndicator {
return &MockedHealthIndicator{
Status: health.StatusUp,
Description: "mocked",
Details: map[string]interface{}{
"key": "value",
},
}
}
func (i *MockedHealthIndicator) Name() string {
return "test"
}
func (i *MockedHealthIndicator) Health(_ context.Context, opts health.Options) health.Health {
ret := health.CompositeHealth{
SimpleHealth: health.SimpleHealth{
Stat: i.Status,
Desc: i.Description,
},
}
if opts.ShowComponents {
detailed := health.DetailedHealth{
SimpleHealth: health.SimpleHealth{
Stat: i.Status,
Desc: "mocked detailed",
},
}
if opts.ShowDetails {
detailed.Details = i.Details
}
ret.Components = map[string]health.Health{
"simple": health.SimpleHealth{
Stat: i.Status,
Desc: "mocked simple",
},
"detailed": detailed,
}
}
return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"bytes"
"context"
"dario.cat/mergo"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/open-policy-agent/opa/download"
opakeys "github.com/open-policy-agent/opa/keys"
"github.com/open-policy-agent/opa/plugins/bundle"
opadiscovery "github.com/open-policy-agent/opa/plugins/discovery"
opalogs "github.com/open-policy-agent/opa/plugins/logs"
oparest "github.com/open-policy-agent/opa/plugins/rest"
opastatus "github.com/open-policy-agent/opa/plugins/status"
opacache "github.com/open-policy-agent/opa/topdown/cache"
"io"
"net/url"
"time"
)
const defaultServerName = `opa-bundle-service`
type ConfigCustomizer interface {
Customize(ctx context.Context, cfg *Config)
}
// Config is a subset OPA Config with typed field
// see OPA's Config.Config and Config.ParseConfig
type Config struct {
Services map[string]*oparest.Config `json:"services,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
Discovery *opadiscovery.Config `json:"discovery,omitempty"`
Bundles map[string]*bundle.Source `json:"bundles,omitempty"`
DecisionLogs *opalogs.Config `json:"decision_logs,omitempty"`
Status *opastatus.Config `json:"status,omitempty"`
Plugins map[string]interface{} `json:"plugins,omitempty"`
Keys map[string]*opakeys.Config `json:"keys,omitempty"`
DefaultDecision *string `json:"default_decision,omitempty"`
DefaultAuthorizationDecision *string `json:"default_authorization_decision,omitempty"`
Caching *opacache.Config `json:"caching,omitempty"`
NDBuiltinCache bool `json:"nd_builtin_cache,omitempty"`
PersistenceDirectory *string `json:"persistence_directory,omitempty"`
DistributedTracing *distributedTracingConfig `json:"distributed_tracing,omitempty"`
Storage *storageConfig `json:"storage,omitempty"`
ExtraConfig map[string]interface{} `json:"-"`
}
func (c Config) MarshalJSON() ([]byte, error) {
type config Config
return marshalMergedJSON(config(c), c.ExtraConfig, minimizeMap)
}
func (c Config) MarshalText() ([]byte, error) {
return json.Marshal(c)
}
func (c Config) JSONReader(ctx context.Context) (io.Reader, error) {
var buf bytes.Buffer
if e := json.NewEncoder(&buf).Encode(&c); e != nil {
return nil, e
}
logger.WithContext(ctx).Debugf("OPA Config: %s", buf.Bytes())
return &buf, nil
}
// see OPA's internal distributedtracing.distributedTracingConfig (internal/distributedtracing/distributedtracing.go)
type distributedTracingConfig struct {
Type string `json:"type,omitempty"`
Address string `json:"address,omitempty"`
ServiceName string `json:"service_name,omitempty"`
SampleRatePercentage *int `json:"sample_percentage,omitempty"`
EncryptionScheme string `json:"encryption,omitempty"`
EncryptionSkipVerify *bool `json:"allow_insecure_tls,omitempty"`
TLSCertFile string `json:"tls_cert_file,omitempty"`
TLSCertPrivateKeyFile string `json:"tls_private_key_file,omitempty"`
TLSCACertFile string `json:"tls_ca_cert_file,omitempty"`
}
type storageConfig struct {
Disk *diskConfig `json:"disk,omitempty"`
}
// see OPA's disk.cfg (disk/Config.go)
type diskConfig struct {
Dir string `json:"directory"`
AutoCreate bool `json:"auto_create"`
Partitions []string `json:"partitions"`
Badger string `json:"badger"`
}
// LoadConfig create config and combine values from defaults and properties
func LoadConfig(ctx context.Context, props Properties, customizers ...ConfigCustomizer) (*Config, error) {
cfg := Config{
Plugins: map[string]interface{}{},
ExtraConfig: map[string]interface{}{},
}
if e := applyProperties(&props, &cfg); e != nil {
return nil, e
}
for _, customizer := range customizers {
customizer.Customize(ctx, &cfg)
}
return &cfg, nil
}
func applyProperties(props *Properties, cfg *Config) error {
// service
serverName := props.Server.Name
if len(serverName) == 0 {
serverName = defaultServerName
}
if _, e := url.Parse(props.Server.URL); e != nil {
return fmt.Errorf(`invalid OPA server URL: %v`, e)
}
cfg.Services = map[string]*oparest.Config{
serverName: {
Name: serverName,
URL: props.Server.URL,
//AllowInsecureTLS: true,
},
}
// decision logs
if props.Logging.DecisionLogsLevel != log.LevelOff {
cfg.Plugins[pluginNameDecisionLogger] = props.Logging
cfg.DecisionLogs = &opalogs.Config{
Plugin: utils.ToPtr(pluginNameDecisionLogger),
}
}
// bundles
cfg.Bundles = map[string]*bundle.Source{}
for k := range props.Bundles {
v := props.Bundles[k]
polling := props.Server.PollingProperties
if e := mergo.Merge(&polling, &v.PollingProperties, mergo.WithOverride); e != nil {
return fmt.Errorf("unable to merge polling properties of bundle [%s]: %v", k, e)
}
cfg.Bundles[k] = &bundle.Source{
Config: download.Config{
Trigger: nil,
Polling: download.PollingConfig{
MinDelaySeconds: asSeconds(polling.PollingMinDelay),
MaxDelaySeconds: asSeconds(polling.PollingMaxDelay),
LongPollingTimeoutSeconds: asSeconds(polling.LongPollingTimeout),
},
},
Service: serverName,
Resource: v.Path,
}
}
return nil
}
func asSeconds(duration *utils.Duration) *int64 {
if duration == nil {
return nil
}
secs := int64(time.Duration(*duration) / time.Second)
return &secs
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"context"
"net/http"
"net/url"
)
/********************
Constants
********************/
const (
PackagePrefixResource = `resource`
)
/********************
Common Inputs
********************/
const (
InputPrefixRoot = `input`
InputPrefixAuthentication = `auth`
InputPrefixRequest = `request`
InputPrefixResource = `resource`
)
type Input struct {
Authentication *AuthenticationClause `json:"auth,omitempty"`
Request *RequestClause `json:"request,omitempty"`
Resource *ResourceClause `json:"resource,omitempty"`
ExtraData map[string]interface{} `json:"-"`
}
func (c Input) MarshalJSON() ([]byte, error) {
type clause Input
return marshalMergedJSON(clause(c), c.ExtraData)
}
func NewInput() *Input {
return &Input{
ExtraData: make(map[string]interface{}),
}
}
type InputCustomizer interface {
Customize(ctx context.Context, input *Input) error
}
type InputCustomizerFunc func(ctx context.Context, input *Input) error
func (fn InputCustomizerFunc) Customize(ctx context.Context, input *Input) error {
return fn(ctx, input)
}
/*****************************
Common Identity Blocks
*****************************/
type AuthenticationClause struct {
// Required fields
UserID string `json:"user_id"`
Permissions []string `json:"permissions"`
// Optional fields
Username string `json:"username,omitempty"`
TenantID string `json:"tenant_id,omitempty"`
ProviderID string `json:"provider_id,omitempty"`
Roles []string `json:"roles,omitempty"`
AccessibleTenants []string `json:"accessible_tenants,omitempty"`
Client *OAuthClientClause `json:"client"`
ExtraData map[string]interface{} `json:"-"`
}
type OAuthClientClause struct {
ClientID string `json:"client_id"`
GrantType string `json:"grant_type,omitempty"`
Scopes []string `json:"scopes"`
}
func (c AuthenticationClause) MarshalJSON() ([]byte, error) {
type clause AuthenticationClause
return marshalMergedJSON(clause(c), c.ExtraData)
}
func NewAuthenticationClause() *AuthenticationClause {
return &AuthenticationClause{
ExtraData: map[string]interface{}{},
}
}
/**************************
Common ResourceQuery Blocks
**************************/
type RequestClause struct {
Scheme string `json:"scheme,omitempty"`
Path string `json:"path,omitempty"`
Method string `json:"method,omitempty"`
Header http.Header `json:"header,omitempty"`
Query url.Values `json:"query,omitempty"`
ExtraData map[string]interface{} `json:"-"`
}
func (c RequestClause) MarshalJSON() ([]byte, error) {
type clause RequestClause
return marshalMergedJSON(clause(c), c.ExtraData)
}
func NewRequestClause(req *http.Request) *RequestClause {
return &RequestClause{
Scheme: req.URL.Scheme,
Path: req.URL.Path,
Method: req.Method,
Header: req.Header,
Query: req.URL.Query(),
}
}
type ResourceOperation string
const (
OpRead ResourceOperation = `read`
OpWrite ResourceOperation = `write`
OpCreate ResourceOperation = `create`
OpDelete ResourceOperation = `delete`
)
type ResourceValues struct {
TenantID string `json:"tenant_id,omitempty"`
TenantPath []string `json:"tenant_path,omitempty"`
OwnerID string `json:"owner_id,omitempty"`
Sharing map[string][]ResourceOperation `json:"sharing,omitempty"`
ExtraData map[string]interface{} `json:"-"`
}
func (c ResourceValues) MarshalJSON() ([]byte, error) {
type clause ResourceValues
return marshalMergedJSON(clause(c), c.ExtraData)
}
type CurrentResourceValues ResourceValues
type ResourceClause struct {
CurrentResourceValues
Type string `json:"type"`
Operation ResourceOperation `json:"op"`
Delta *ResourceValues `json:"delta,omitempty"`
}
func NewResourceClause(resType string, op ResourceOperation) *ResourceClause {
return &ResourceClause{
Type: resType,
Operation: op,
}
}
func (c ResourceClause) MarshalJSON() ([]byte, error) {
type clause ResourceClause
return marshalMergedJSON(clause(c), c.ExtraData)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package constraints
import (
"database/sql/driver"
"github.com/cisco-open/go-lanai/pkg/data/types/pqx"
"github.com/cisco-open/go-lanai/pkg/opa"
"github.com/google/uuid"
)
const (
SharedPermissionRead = SharedPermission(opa.OpRead)
SharedPermissionWrite = SharedPermission(opa.OpWrite)
SharedPermissionDelete = SharedPermission(opa.OpDelete)
)
type SharedPermission opa.ResourceOperation
// Sharing is a Model type that stores mapping between user IDs and a list of allowed permissions as JSONB map
// This type works with OPA sharing policy
type Sharing map[uuid.UUID][]SharedPermission
// Value implements driver.Valuer
func (s Sharing) Value() (driver.Value, error) {
return pqx.JsonbValue(s)
}
// Scan implements sql.Scanner
func (s *Sharing) Scan(src interface{}) error {
return pqx.JsonbScan(src, s)
}
func (s Sharing) GormDataType() string {
return "jsonb"
}
func (s Sharing) Share(userID uuid.UUID, perms ...SharedPermission) {
if len(perms) == 0 {
delete(s, userID)
} else {
s[userID] = perms
}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opadata
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/opa"
"strings"
)
/*****************
Constants
*****************/
const (
DefaultQueryTemplate = `allow_%s`
DefaultPartialQueryTemplate = `filter_%s`
)
/*****************
Errors
*****************/
var (
ErrQueryTranslation = opa.NewError(`generic query translation error`)
ErrUnsupportedUsage = opa.NewError(`generic unsupported usage error`)
)
/*****************
Tag
*****************/
const (
TagOPA = `opa`
TagDelimiter = `,`
TagAssignment = `:`
TagValueIgnore = "-"
TagKeyInputField = `field`
TagKeyInputFieldAlt = `input`
TagKeyResourceType = `type`
TagKeyOPAPackage = `package`
)
// OPATag supported key-value pairs in `opa` tag.
// `opa` tag is in format of `opa:"<key>:<value>, [<more_keys>:<more_values>, ...]".
// Unless specified, each key-value pair only takes effect on either "to-be-filtered-by" model fields (Model Fields)
// or FilteredModel (regardless if embedded or as a field), but not both.
type OPATag struct {
// InputField Required on "to-be-filtered-by" model fields. Specify mappings between model field and OPA input fields.
// e.g. `opa:"field:myProperty"` translate to `input.resource.myProperty` in OPA input
InputField string
// ResType Required on FilteredModel. This value contributes to both OPA query and OPA input:
// - ResType is set to OPA input as `input.resource.type`
// - Unless OPAPackage or Policies is specified, ResType is also part of OPA query:
// "data.resource.{{RestType}}.<filter|allow>_{{DBOperationFlag}}"
ResType string
// OPAPackage Optional on FilteredModel. Used to overwrite default OPA query.
// Resulting query is "data.{{OPAPackage}}.<filter|allow>_{{DBOperationFlag}}"
// e.g. `opa:"type:my_res, package:my.res" -> the OPA query is "data.my.res.filter_{{DBOperationFlag}}"
OPAPackage string
// Policies Optional on FilteredModel. Fine control of OPA queries for each type of DB operation.
// - If set to "-", the corresponding DB operation is disabled for data-filtering.
// e.g. `opa:"type:my_res, read:-"` disables OPA data filtering for read operations (SELECT statements)
// - If set to any other non-empty string, it's used to construct OPA query
// e.g. `opa:"type:my_res, read:my_custom_filter"` -> OPA query "data.resource.my_res.my_custom_filter" is used for read operations.
Policies map[DBOperationFlag]string
// mode bitwise flags for enabled/disabled DB operations
mode policyMode
}
func (t *OPATag) UnmarshalText(data []byte) error {
// setup default
*t = OPATag{
mode: defaultPolicyMode,
}
// parse kv pairs
terms := strings.Split(string(data), TagDelimiter)
for _, term := range terms {
term = strings.TrimSpace(term)
if len(term) == 0 {
continue
}
kv := strings.SplitN(term, TagAssignment, 2)
var v string
switch len(kv) {
case 1:
v = `true`
case 2:
v = strings.TrimSpace(kv[1])
default:
return fmt.Errorf(`invalid "opa" tag format, expect "key:model", but got "%s"`, term)
}
k := strings.TrimSpace(kv[0])
switch k {
case TagKeyInputField, TagKeyInputFieldAlt:
t.InputField = v
case TagKeyResourceType:
t.ResType = v
case TagKeyOPAPackage:
t.OPAPackage = v
default:
if e := t.parsePolicy(kv); e == nil {
continue
}
return ErrUnsupportedUsage.WithMessage(`invalid "opa" tag, unrecognized key "%s"`, k)
}
}
return nil
}
func (t *OPATag) parsePolicy(kv []string) error {
if len(kv) != 2 {
return fmt.Errorf(`invalid policy, expect <mode>%s<policy_name>`, TagAssignment)
}
var flag DBOperationFlag
if e := flag.UnmarshalText([]byte(strings.TrimSpace(kv[0]))); e != nil {
return e
}
if t.Policies == nil {
t.Policies = map[DBOperationFlag]string{}
}
t.Policies[flag] = strings.TrimSpace(kv[1])
if kv[1] == TagValueIgnore {
t.mode = t.mode & ^policyMode(flag)
} else {
t.mode = t.mode | policyMode(flag)
}
return nil
}
// Queries normalized queries for OPA.
// By default, queries are
func (t *OPATag) Queries() map[DBOperationFlag]string {
//TODO
return t.Policies
}
/********************
Flags and Mode
********************/
const (
DBOperationFlagCreate DBOperationFlag = 1 << iota
DBOperationFlagRead
DBOperationFlagUpdate
DBOperationFlagDelete
)
const (
dbOpTextCreate = `create`
dbOpTextRead = `read`
dbOpTextUpdate = `update`
dbOpTextDelete = `delete`
)
// DBOperationFlag bitwise Flag of tenancy flag mode
type DBOperationFlag uint
func (f DBOperationFlag) MarshalText() ([]byte, error) {
switch f {
case DBOperationFlagCreate:
return []byte(dbOpTextCreate), nil
case DBOperationFlagRead:
return []byte(dbOpTextRead), nil
case DBOperationFlagUpdate:
return []byte(dbOpTextUpdate), nil
case DBOperationFlagDelete:
return []byte(dbOpTextDelete), nil
}
return []byte{}, nil
}
func (f *DBOperationFlag) UnmarshalText(data []byte) error {
switch v := string(data); v {
case dbOpTextCreate:
*f = DBOperationFlagCreate
case dbOpTextRead:
*f = DBOperationFlagRead
case dbOpTextUpdate:
*f = DBOperationFlagUpdate
case dbOpTextDelete:
*f = DBOperationFlagDelete
default:
return fmt.Errorf("unrecognized DB operation flag '%s'", string(data))
}
return nil
}
const (
defaultPolicyMode = policyMode(DBOperationFlagCreate | DBOperationFlagRead | DBOperationFlagUpdate | DBOperationFlagDelete)
)
// policyMode enum of policyMode
type policyMode uint
//goland:noinspection GoMixedReceiverTypes
func (m policyMode) hasFlags(flags ...DBOperationFlag) bool {
for _, flag := range flags {
if m&policyMode(flag) == 0 {
return false
}
}
return true
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"embed"
"github.com/cisco-open/go-lanai/pkg/tenancy"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/test/mocks"
"github.com/cisco-open/go-lanai/test/sectest"
testutils "github.com/cisco-open/go-lanai/test/utils"
"github.com/google/uuid"
)
//go:embed *.sql *.yml uuid_pool.txt
var ModelDataFS embed.FS
var (
MockedAdminId = uuid.MustParse("710e8219-ed8d-474e-8f7d-96b27e46dba9")
MockedUserId1 = uuid.MustParse("595959e4-8803-4ab1-8acf-acfb92bb7322")
MockedUserId2 = uuid.MustParse("9a901c91-a3d6-4d39-9adf-34e74bb32de2")
MockedUserId3 = uuid.MustParse("e212a869-b636-4dc6-83db-e1ccd59e5e0e")
MockedRootTenantId = uuid.MustParse("23967dfe-d90f-4e1b-9406-e2df6685f232")
MockedTenantIdA = uuid.MustParse("d8423acc-28cb-4209-95d6-089de7fb27ef")
MockedTenantIdB = uuid.MustParse("37b7181a-0892-4706-8f26-60d286b63f14")
MockedTenantIdA1 = uuid.MustParse("be91531e-ca96-46eb-aea6-b7e0e2a50e21")
MockedTenantIdA2 = uuid.MustParse("b50c18d9-1741-49bd-8536-30943dfffb45")
MockedTenantIdB1 = uuid.MustParse("1513b015-6a7d-4de3-8b4f-cbb090ac126d")
MockedTenantIdB2 = uuid.MustParse("b21445de-9192-45de-acd7-91745ab4cc13")
)
/*************************
ID Pool
*************************/
func NewUUIDPool() (*testutils.UUIDPool, error) {
return testutils.NewUUIDPool(ModelDataFS, "uuid_pool.txt")
}
/*************************
Tenancy
*************************/
func ProvideMockedTenancyAccessor() tenancy.Accessor {
tenancyRelationship := []mocks.TenancyRelation{
{Parent: MockedRootTenantId, Child: MockedTenantIdA},
{Parent: MockedRootTenantId, Child: MockedTenantIdB},
{Parent: MockedTenantIdA, Child: MockedTenantIdA1},
{Parent: MockedTenantIdA, Child: MockedTenantIdA2},
{Parent: MockedTenantIdB, Child: MockedTenantIdB1},
{Parent: MockedTenantIdB, Child: MockedTenantIdB2},
}
return mocks.NewMockTenancyAccessor(tenancyRelationship, MockedRootTenantId)
}
/*************************
Users
*************************/
func ContextWithSecurityMock(parent context.Context, mockOpts ...sectest.SecurityMockOptions) context.Context {
return sectest.ContextWithSecurity(parent, sectest.MockedAuthentication(mockOpts...))
}
// mockedUserProfile describes the static identity and tenancy of one mocked user.
// It holds the population logic shared by AdminSecurityOptions, User1SecurityOptions,
// User2SecurityOptions and User3SecurityOptions.
type mockedUserProfile struct {
	username       string
	userID         string
	externalTenant string
	permissions    []string
	roles          []string
	// assignedTenant is the only tenant initially put into SecurityDetailsMock.Tenants
	assignedTenant string
	// currentTenant is the default SecurityDetailsMock.TenantId
	currentTenant string
}

// securityOptions returns SecurityMockOptions that overwrite the identity, permissions, roles and tenancy
// of the mocked security details with this profile. Each invocation creates fresh permission/role/tenant sets.
// When tenantId is non-empty, its first element replaces the current tenant and is added to the tenant set.
func (p mockedUserProfile) securityOptions(tenantId []uuid.UUID) sectest.SecurityMockOptions {
	return func(d *sectest.SecurityDetailsMock) {
		d.Username = p.username
		d.UserId = p.userID
		d.TenantExternalId = p.externalTenant
		d.Permissions = utils.NewStringSet(p.permissions...)
		d.Roles = utils.NewStringSet(p.roles...)
		d.Tenants = utils.NewStringSet(p.assignedTenant)
		d.TenantId = p.currentTenant
		if len(tenantId) != 0 {
			d.TenantId = tenantId[0].String()
			d.Tenants.Add(d.TenantId)
		}
	}
}

// AdminSecurityOptions mocks user "admin" with VIEW/MANAGE permissions and ADMIN role,
// assigned to and currently in the root tenant. An optional tenantId overrides the current tenant.
func AdminSecurityOptions(tenantId ...uuid.UUID) sectest.SecurityMockOptions {
	return mockedUserProfile{
		username:       "admin",
		userID:         MockedAdminId.String(),
		externalTenant: "Root Tenant",
		permissions:    []string{"VIEW", "MANAGE"},
		roles:          []string{"ADMIN"},
		assignedTenant: MockedRootTenantId.String(),
		currentTenant:  MockedRootTenantId.String(),
	}.securityOptions(tenantId)
}

// User1SecurityOptions mocks user "user1" with NO_VIEW permission and USER role,
// assigned to tenant A and currently in tenant A1. An optional tenantId overrides the current tenant.
func User1SecurityOptions(tenantId ...uuid.UUID) sectest.SecurityMockOptions {
	return mockedUserProfile{
		username:       "user1",
		userID:         MockedUserId1.String(),
		externalTenant: "Tenant A",
		permissions:    []string{"NO_VIEW"},
		roles:          []string{"USER"},
		assignedTenant: MockedTenantIdA.String(),
		currentTenant:  MockedTenantIdA1.String(),
	}.securityOptions(tenantId)
}

// User2SecurityOptions mocks user "user2" with NO_VIEW permission and USER role,
// assigned to tenant B and currently in tenant B1. An optional tenantId overrides the current tenant.
func User2SecurityOptions(tenantId ...uuid.UUID) sectest.SecurityMockOptions {
	return mockedUserProfile{
		username:       "user2",
		userID:         MockedUserId2.String(),
		externalTenant: "Tenant B",
		permissions:    []string{"NO_VIEW"},
		roles:          []string{"USER"},
		assignedTenant: MockedTenantIdB.String(),
		currentTenant:  MockedTenantIdB1.String(),
	}.securityOptions(tenantId)
}

// User3SecurityOptions mocks user "user3" with NO_VIEW permission and USER role,
// assigned to tenant A and currently in tenant A1. An optional tenantId overrides the current tenant.
func User3SecurityOptions(tenantId ...uuid.UUID) sectest.SecurityMockOptions {
	return mockedUserProfile{
		username:       "user3",
		userID:         MockedUserId3.String(),
		externalTenant: "Tenant A",
		permissions:    []string{"NO_VIEW"},
		roles:          []string{"USER"},
		assignedTenant: MockedTenantIdA.String(),
		currentTenant:  MockedTenantIdA1.String(),
	}.securityOptions(tenantId)
}
// ExtraPermsSecurityOptions returns SecurityMockOptions that grant additional permissions on top of
// the permissions set by previously applied options.
// If no permission set exists yet (e.g. it is applied before any user option), an empty set is created
// first instead of panicking on a nil set.
func ExtraPermsSecurityOptions(permissions ...string) sectest.SecurityMockOptions {
	return func(d *sectest.SecurityDetailsMock) {
		if d.Permissions == nil {
			d.Permissions = utils.NewStringSet()
		}
		d.Permissions.Add(permissions...)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opadata
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"gorm.io/gorm/schema"
"reflect"
"strings"
"sync"
)
// Package-level caches:
//   - metadataCache maps a model's reflect.Type (schema.Schema.ModelType) to its resolved *Metadata.
//   - schemaCache is handed to schema.Parse so each model's GORM schema is parsed only once.
var (
	metadataCache = new(sync.Map)
	schemaCache   = new(sync.Map)
)
// Reflection types of the OPA marker structs and of the generic map type, computed once at package init.
var (
	typeFilteredModel = reflect.TypeOf(FilteredModel{})
	typeFilter        = reflect.TypeOf(Filter{})
	typeGenericMap    = reflect.TypeOf(map[string]interface{}{})
)

// policyMarkerTypes holds every type, in value or pointer form, whose presence in a model struct marks it
// as subject to OPA policy-based data filtering.
var policyMarkerTypes = utils.NewSet(
	typeFilteredModel,
	reflect.PointerTo(typeFilteredModel),
	typeFilter,
	reflect.PointerTo(typeFilter),
)
const (
	// errTmplEmbeddedStructNotFound is formatted with the schema name when no marker type (FilteredModel/Filter)
	// is found in a model struct.
	errTmplEmbeddedStructNotFound = `FilteredModel or Filter not found in model struct [%s]. Tips: embedding 'FilteredModel' or having field with type 'Filter' is required for any OPA DB usage`
	// errTmplOPATagNotFound is formatted with the schema name when the marker field has no "opa" tag,
	// or its "opa" tag doesn't define a resource type.
	errTmplOPATagNotFound = `'opa' tag is not found on FilteredModel or Filter in model struct [%s]. Tips: the FilteredModel or Filter should have 'opa' tag with at least resource type defined`
)
/*******************
Metadata
*******************/
// TaggedField is a GORM schema field annotated with an "opa" tag.
// RelationPath is empty for fields declared directly on the model. Otherwise it lists the tagged
// relationships traversed to reach the field.
type TaggedField struct {
	schema.Field
	OPATag       OPATag
	RelationPath TaggedRelationPath
}

// InputField returns the dotted OPA input name of the field: the field's own input name,
// prefixed by the input names of its relation path (if any).
func (f TaggedField) InputField() string {
	name := f.OPATag.InputField
	if len(f.RelationPath) != 0 {
		name = f.RelationPath.InputField() + "." + name
	}
	return name
}
// TaggedRelationship is a GORM relationship annotated with an "opa" tag.
type TaggedRelationship struct {
	schema.Relationship
	OPATag OPATag
}

// TaggedRelationPath is an ordered chain of tagged relationships, from the root model outwards.
type TaggedRelationPath []*TaggedRelationship

// InputField joins the OPA input names of every relationship in the path with ".".
// An empty path yields "".
func (path TaggedRelationPath) InputField() string {
	var sb strings.Builder
	for i, rel := range path {
		if i > 0 {
			sb.WriteByte('.')
		}
		sb.WriteString(rel.OPATag.InputField)
	}
	return sb.String()
}
// Metadata contains all static/declarative information of a model struct:
// the marker's parsed OPATag, every "opa"-tagged field (including those reached through tagged relationships)
// keyed by its dotted OPA input name, and the parsed GORM schema.
type Metadata struct {
	OPATag
	Fields map[string]*TaggedField
	Schema *schema.Schema
}

// newMetadata collects the tagged fields of s and parses its marker tag.
// Field collection errors are reported before tag errors.
func newMetadata(s *schema.Schema) (*Metadata, error) {
	fields, err := collectAllFields(s)
	if err != nil {
		return nil, err
	}
	tag, err := parseTag(s)
	if err != nil {
		return nil, err
	}
	meta := &Metadata{
		Schema: s,
		Fields: fields,
		OPATag: *tag,
	}
	return meta, nil
}
// resolveMetadata parses model into a GORM schema (cached in schemaCache, default naming strategy)
// and returns the Metadata of that schema, loading it on first use.
func resolveMetadata(model interface{}) (*Metadata, error) {
	var naming schema.NamingStrategy
	parsed, err := schema.Parse(model, schemaCache, naming)
	if err != nil {
		return nil, err
	}
	return loadMetadata(parsed)
}
// loadMetadata returns the cached Metadata for s.ModelType, building and caching it on a miss.
// Failures are not cached.
func loadMetadata(s *schema.Schema) (*Metadata, error) {
	if cached, found := metadataCache.Load(s.ModelType); found {
		return cached.(*Metadata), nil
	}
	meta, err := newMetadata(s)
	if err != nil {
		return nil, err
	}
	// a concurrent caller may have stored metadata for the same type meanwhile; whichever was stored first wins
	actual, _ := metadataCache.LoadOrStore(s.ModelType, meta)
	return actual.(*Metadata), nil
}
// collectAllFields gathers the "opa"-tagged fields declared on s, plus those reachable through its tagged
// relationships, keyed by dotted OPA input name. On error, the partially filled map is returned with it.
func collectAllFields(s *schema.Schema) (map[string]*TaggedField, error) {
	fields := map[string]*TaggedField{}
	if err := collectFields(s, fields); err != nil {
		return fields, err
	}
	for _, rel := range s.Relationships.Relations {
		// each top-level relationship starts its own traversal with a fresh visited set
		if err := collectRelationship(rel, nil, utils.NewSet(), fields); err != nil {
			return fields, err
		}
	}
	return fields, nil
}
// collectFields adds every DB-mapped field of s that carries an "opa" tag to dest, keyed by its OPA input name.
// It returns ErrUnsupportedUsage in three cases:
//   - the tag is placed on the only primary key;
//   - the tag cannot be parsed;
//   - the tag doesn't define an input field.
func collectFields(s *schema.Schema, dest map[string]*TaggedField) error {
	for _, f := range s.Fields {
		raw, hasTag := f.Tag.Lookup(TagOPA)
		if !hasTag || len(f.DBName) == 0 {
			continue
		}
		if f.PrimaryKey && len(s.PrimaryFields) == 1 {
			return ErrUnsupportedUsage.WithMessage(`"%s" tag cannot be used on single primary key`, TagOPA)
		}
		field := &TaggedField{Field: *f}
		if e := field.OPATag.UnmarshalText([]byte(raw)); e != nil {
			return ErrUnsupportedUsage.WithMessage(`invalid "%s" tag on %s.%s: %v`, TagOPA, s.Name, f.Name, e)
		}
		if len(field.OPATag.InputField) == 0 {
			return ErrUnsupportedUsage.WithMessage(`invalid "%s" tag on %s.%s: "%s" or "%s" is required`, TagOPA, s.Name, f.Name, TagKeyInputField, TagKeyInputFieldAlt)
		}
		dest[field.OPATag.InputField] = field
	}
	return nil
}
// collectRelationship collects the "opa"-tagged fields reachable through relationship r into dest.
// Each field is keyed by its full dotted input name: the relation path's input names followed by the field's
// own input name.
//
// A relationship is skipped if it has no "opa" tag, or if its target schema is already in visited.
// visited is shared by the whole recursive traversal started for one top-level relationship, which also stops
// infinite recursion on cyclic relationships.
// path is the chain of tagged relationships leading to r (nil for relationships declared on the root model).
//
// Returns ErrUnsupportedUsage if the relationship's "opa" tag is invalid or lacks an input field, and propagates
// any error from collectFields or from nested relationships.
func collectRelationship(r *schema.Relationship, path TaggedRelationPath, visited utils.Set, dest map[string]*TaggedField) error {
	tag, ok := r.Field.Tag.Lookup(TagOPA)
	if !ok || visited.Has(r.FieldSchema) {
		return nil
	}
	visited.Add(r.FieldSchema)
	// parse OPA tag of given relation
	taggedR := TaggedRelationship{
		Relationship: *r,
	}
	switch e := taggedR.OPATag.UnmarshalText([]byte(tag)); {
	case e != nil:
		return ErrUnsupportedUsage.WithMessage(`invalid "%s" tag on %s.%s: %v`, TagOPA, r.Schema.Name, r.Field.Name, e)
	case len(taggedR.OPATag.InputField) == 0:
		return ErrUnsupportedUsage.WithMessage(`invalid "%s" tag on %s.%s: "%s" or "%s" is required`, TagOPA, r.Schema.Name, r.Field.Name, TagKeyInputField, TagKeyInputFieldAlt)
	}
	path = append(path, &taggedR)
	// collect fields of relationship's fields
	fields := map[string]*TaggedField{}
	if e := collectFields(r.FieldSchema, fields); e != nil {
		return e
	}
	for _, tagged := range fields {
		// give every field its own copy of the path: "path" may share its backing array with sibling
		// recursive calls, whose later appends would otherwise overwrite its elements
		tagged.RelationPath = make([]*TaggedRelationship, len(path))
		copy(tagged.RelationPath, path)
		dest[tagged.InputField()] = tagged
	}
	// recursively collect fields of relationship
	// (the loop variable shadows r; the range expression is evaluated using the outer r)
	for _, r := range r.FieldSchema.Relationships.Relations {
		if e := collectRelationship(r, path, visited, dest); e != nil {
			return e
		}
	}
	return nil
}
// parseTag locates the marker field (FilteredModel or Filter) of s's model type, validates its placement and
// parses its "opa" tag.
// It returns an error in four cases:
//   - no marker field is found;
//   - the marker field's "gorm" tags are misused;
//   - the tag is missing or cannot be parsed;
//   - the tag doesn't define a resource type.
func parseTag(s *schema.Schema) (*OPATag, error) {
	marker, found := findMarkerField(s.ModelType)
	if !found {
		return nil, fmt.Errorf(errTmplEmbeddedStructNotFound, s.Name)
	}
	if err := validateMarkerField(s.ModelType, marker); err != nil {
		return nil, err
	}
	raw, hasTag := marker.Tag.Lookup(TagOPA)
	if !hasTag {
		return nil, fmt.Errorf(errTmplOPATagNotFound, s.Name)
	}
	parsed := &OPATag{}
	if err := parsed.UnmarshalText([]byte(raw)); err != nil {
		return nil, err
	}
	if len(parsed.ResType) == 0 {
		return nil, fmt.Errorf(errTmplOPATagNotFound, s.Name)
	}
	return parsed, nil
}
// findMarkerField searches typ depth-first, in declaration order, for a field whose type is one of
// policyMarkerTypes. It descends into anonymous (embedded) struct fields.
// The returned field's Index is relative to typ, so it can be passed to reflect.Type.FieldByIndex.
//
// Only struct kinds are searched. Any other kind yields no match instead of panicking in NumField. This covers
// embedded pointers, interfaces and named basic types. Embedded struct values cannot be recursive, so the
// recursion always terminates.
func findMarkerField(typ reflect.Type) (reflect.StructField, bool) {
	if typ.Kind() != reflect.Struct {
		return reflect.StructField{}, false
	}
	for i := 0; i < typ.NumField(); i++ {
		f := typ.Field(i)
		if policyMarkerTypes.Has(f.Type) {
			return f, true
		}
		if !f.Anonymous {
			continue
		}
		if nested, ok := findMarkerField(f.Type); ok {
			// prefix the nested index with the embedded field's own index, in a freshly allocated slice
			index := make([]int, 0, len(f.Index)+len(nested.Index))
			index = append(index, f.Index...)
			nested.Index = append(index, nested.Index...)
			return nested, true
		}
	}
	return reflect.StructField{}, false
}
// validateMarkerField checks how the marker field found in typ is declared:
//   - a named (non-embedded) marker field must carry a "gorm" tag (`gorm:"-"` is expected);
//   - no embedded struct on the path from typ down to the marker (the marker included) may carry a "gorm" tag.
//
// field.Index must be relative to typ, as returned by findMarkerField.
func validateMarkerField(typ reflect.Type, field reflect.StructField) error {
	if _, hasGormTag := field.Tag.Lookup("gorm"); !hasGormTag && !field.Anonymous {
		return fmt.Errorf(`gorm:"-" tag is required on Filter field`)
	}
	// inspect every prefix of the index path, i.e. each field traversed from typ down to the marker
	for depth := 1; depth <= len(field.Index); depth++ {
		step := typ.FieldByIndex(field.Index[:depth])
		if _, hasGormTag := step.Tag.Lookup("gorm"); step.Anonymous && hasGormTag {
			return fmt.Errorf(`"gorm" tag is not allowed on embedded struct containing FilteredModel`)
		}
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opadata
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/opa"
"gorm.io/gorm"
"reflect"
"strings"
)
//var logger = log.New("OPA.Data")
/**********************
Global Functions
**********************/
// ResolveResource parses the given model and resolves its OPA resource type and resource values using "opa" tags.
// Typically used together with opa.AllowResource as manual policy enforcement.
// ModelType should be a model struct with FilteredModel (or Filter) and valid "opa" tags.
// Errors are ErrUnsupportedUsage and are returned in three cases:
//   - model is not a non-nil pointer to struct;
//   - its metadata cannot be resolved;
//   - its values cannot be extracted.
//
// Note: resValues could be nil if all OPA related values are zero.
func ResolveResource[ModelType any](model *ModelType) (resType string, resValues *opa.ResourceValues, err error) {
	ptr := reflect.ValueOf(model)
	if isStructPtr := ptr.Kind() == reflect.Pointer && ptr.Elem().Kind() == reflect.Struct; !isStructPtr {
		return "", nil, ErrUnsupportedUsage.WithMessage(`unable to resolve metadata of "%T": model need to be a struct`, model)
	}
	meta, e := resolveMetadata(model)
	if e != nil {
		return "", nil, ErrUnsupportedUsage.WithMessage(`unable to resolve metadata of "%T": %v`, model, e)
	}
	target := policyTarget{
		meta:       meta,
		modelPtr:   ptr,
		modelValue: ptr.Elem(),
		model:      model,
	}
	values, e := target.toResourceValues()
	if e != nil {
		return "", nil, ErrUnsupportedUsage.WithMessage(`unable to extract OPA resource values of "%T": %v`, model, e)
	}
	return meta.ResType, values, nil
}
/********************
GORM Scopes
********************/
// SkipFiltering is used as a scope for gorm.DB to skip policy-based data filtering for all DB operations.
// It is FilterByPolicies with no operation enabled.
// e.g. db.WithContext(ctx).Scopes(SkipFiltering()).Find(...)
// Using this scope without context would panic
func SkipFiltering() func(*gorm.DB) *gorm.DB {
	// no flags => empty policy mode => every operation is skipped
	return FilterByPolicies()
}
// FilterByPolicies is used as a scope for gorm.DB to override policy-based data filtering.
// The specified operations are enabled, and the rest are disabled.
// The override is stored in the statement's context and takes precedence over the model's declared mode.
// e.g. db.WithContext(ctx).Scopes(FilterByPolicies(DBOperationFlagRead)).Find(...)
// Using this scope without context would panic
func FilterByPolicies(flags ...DBOperationFlag) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		parent := tx.Statement.Context
		if parent == nil {
			panic("FilterByPolicies scope is used without context")
		}
		var enabled policyMode
		for _, f := range flags {
			enabled |= policyMode(f)
		}
		tx.Statement.Context = context.WithValue(parent, ckFilterMode{}, enabled)
		return tx
	}
}
// FilterWithQueries is used as a scope for gorm.DB to override policy-based data filtering.
// It customizes the queries of the specified DB operations.
// Additional DBOperationFlag-string pairs can be provided via "more".
// e.g. db.WithContext(ctx).Scopes(FilterWithQueries(DBOperationFlagRead, "resource.type.allow_read")).Find(...)
// Important: This scope accepts the FULL QUERY, including the policy package.
// Notes:
//   - It's recommended to use dotted format without leading "data.". FilteredModel would adjust the format based on operation.
//     e.g. "resource.type.allow_read"
//   - Empty queries are ignored.
//   - This scope doesn't enable/disable data-filtering. It only overrides queries set in tag.
//   - Queries set by an outer scope are kept unless overridden; the parent context's data is never mutated.
//   - Using this scope without context would panic
//   - Having incorrect parameters (odd number of "more" values, or values that are not DBOperationFlag/string pairs) causes panic
func FilterWithQueries(op DBOperationFlag, query string, more ...interface{}) func(*gorm.DB) *gorm.DB {
	const errPairs = "FilterWithQueries scope only supports DBOperationFlag and string pairs"
	if len(more)%2 != 0 {
		panic(errPairs)
	}
	policies := map[DBOperationFlag]string{op: query}
	for i := 0; i < len(more); i += 2 {
		flag, ok := more[i].(DBOperationFlag)
		if !ok {
			panic(errPairs)
		}
		q, ok := more[i+1].(string)
		if !ok {
			panic(errPairs)
		}
		if len(q) != 0 {
			policies[flag] = q
		}
	}
	return func(tx *gorm.DB) *gorm.DB {
		ctx := tx.Statement.Context
		if ctx == nil {
			panic("FilterWithQueries scope is used without context")
		}
		// Context values must be treated as immutable: the existing map may be shared with parent contexts
		// (and therefore other statements or goroutines), so build a merged copy instead of mutating it.
		existing, _ := ctx.Value(ckFilterQueries{}).(map[DBOperationFlag]string)
		merged := make(map[DBOperationFlag]string, len(existing)+len(policies))
		for flag, p := range existing {
			merged[flag] = p
		}
		for flag, p := range policies {
			if len(p) != 0 {
				merged[flag] = p
			}
		}
		tx.Statement.Context = context.WithValue(ctx, ckFilterQueries{}, merged)
		return tx
	}
}
// FilterWithExtraData is used as a scope for gorm.DB to provide extra key-value pairs as input during policy-based data filtering.
// The extra KV pairs are added under `input.resource`
// e.g. db.WithContext(ctx).Scopes(FilterWithExtraData("exception", "ignore_tenancy")).Find(...)
// kvs is consumed as consecutive key/value pairs. A trailing key without a value is ignored, as are pairs with
// an empty key. Extra data set by an outer scope is kept unless overridden; the parent context's data is never
// mutated.
// Using this scope without context would panic
func FilterWithExtraData(kvs ...string) func(*gorm.DB) *gorm.DB {
	return func(tx *gorm.DB) *gorm.DB {
		ctx := tx.Statement.Context
		if ctx == nil {
			panic("FilterWithExtraData scope is used without context")
		}
		// Context values must be treated as immutable: build a merged copy instead of mutating a map that
		// may be shared with parent contexts.
		existing, _ := ctx.Value(ckFilterExtraData{}).(map[string]interface{})
		merged := make(map[string]interface{}, len(existing)+len(kvs)/2)
		for k, v := range existing {
			merged[k] = v
		}
		// step by 2 so that each value is never re-read as the key of another pair
		for i := 0; i+1 < len(kvs); i += 2 {
			if len(kvs[i]) != 0 {
				merged[kvs[i]] = kvs[i+1]
			}
		}
		tx.Statement.Context = context.WithValue(ctx, ckFilterExtraData{}, merged)
		return tx
	}
}
/********************
Helpers
********************/
// Context keys used by the GORM scopes in this file to pass per-statement overrides.
// They are distinct, unexported, zero-size types, so they cannot collide with keys defined in other packages.
type (
	// ckFilterMode keys the policyMode set by FilterByPolicies / SkipFiltering.
	ckFilterMode struct{}
	// ckFilterQueries keys the map[DBOperationFlag]string set by FilterWithQueries.
	ckFilterQueries struct{}
	// ckFilterExtraData keys the map[string]interface{} set by FilterWithExtraData.
	ckFilterExtraData struct{}
)
// shouldSkip reports whether policy-based filtering should be skipped for the DB operation "flag".
// A policyMode stored on ctx (via FilterByPolicies / SkipFiltering) takes precedence. Otherwise, fallback
// decides; callers in this package pass the model's own mode.
// Filtering is skipped when the effective mode does not include flag.
//
// A nil ctx cannot carry an override, so it uses fallback as well. Previously a nil ctx returned
// defaultPolicyMode.hasFlags(flag) without negation, which skipped filtering exactly for the enabled operations.
func shouldSkip(ctx context.Context, flag DBOperationFlag, fallback policyMode) bool {
	if ctx != nil {
		if override, ok := ctx.Value(ckFilterMode{}).(policyMode); ok {
			return !override.hasFlags(flag)
		}
	}
	return !fallback.hasFlags(flag)
}
// resolveQuery determines the OPA query for the given operation, formatted for partial or full evaluation
// (see finalizeQuery). Precedence:
//  1. a non-empty query set on ctx via FilterWithQueries, used as-is;
//  2. otherwise "data.<package>.<policy>" built from meta's OPA tag:
//     - a missing package defaults to "<opa.PackagePrefixResource>.<resource type>";
//     - a missing (or ignored "-") policy defaults to the operation's DefaultQueryTemplate /
//       DefaultPartialQueryTemplate;
//  3. "" when meta declares neither package nor policy.
//
// ctx must not be nil.
func resolveQuery(ctx context.Context, flag DBOperationFlag, isPartial bool, meta *Metadata) string {
	// ad-hoc override; a lookup in a nil map yields ""
	adhoc, _ := ctx.Value(ckFilterQueries{}).(map[DBOperationFlag]string)
	if q := adhoc[flag]; len(q) != 0 {
		return finalizeQuery(q, isPartial)
	}

	// declarative info
	pkg, policy := meta.OPAPackage, meta.Policies[flag]
	if policy == TagValueIgnore {
		policy = ""
	}

	switch {
	case len(pkg) == 0 && len(policy) == 0:
		// everything default
		return ""
	case len(pkg) == 0:
		pkg = fmt.Sprintf("%s.%s", opa.PackagePrefixResource, meta.ResType)
	case len(policy) == 0:
		tmpl := DefaultQueryTemplate
		if isPartial {
			tmpl = DefaultPartialQueryTemplate
		}
		policy = fmt.Sprintf(tmpl, flagToResOp(flag))
	}
	return finalizeQuery(fmt.Sprintf("data.%s.%s", pkg, policy), isPartial)
}
// populateExtraData copies the extra key-value pairs set via FilterWithExtraData (if any) into input,
// overwriting existing keys. input must be non-nil whenever extra data is present.
func populateExtraData(ctx context.Context, input map[string]interface{}) {
	if extra, ok := ctx.Value(ckFilterExtraData{}).(map[string]interface{}); ok {
		for key, val := range extra {
			input[key] = val
		}
	}
}
// flagToResOp maps a DB operation flag to the corresponding OPA resource operation:
// read → OpRead, update → OpWrite, delete → OpDelete, anything else (including create) → OpCreate.
func flagToResOp(flag DBOperationFlag) (op opa.ResourceOperation) {
	switch flag {
	case DBOperationFlagRead:
		op = opa.OpRead
	case DBOperationFlagUpdate:
		op = opa.OpWrite
	case DBOperationFlagDelete:
		op = opa.OpDelete
	default:
		op = opa.OpCreate
	}
	return op
}
// finalizeQuery normalizes a query path for the evaluation mode:
//   - partial evaluation: dot-separated and rooted at "data." (e.g. "data.resource.type.filter_read");
//   - full evaluation: slash-separated without the leading "data/" (e.g. "resource/type/allow_read").
func finalizeQuery(query string, isPartial bool) string {
	if !isPartial {
		return strings.TrimPrefix(strings.ReplaceAll(query, ".", "/"), "data/")
	}
	dotted := strings.ReplaceAll(query, "/", ".")
	if strings.HasPrefix(dotted, "data.") {
		return dotted
	}
	return "data." + dotted
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opadata
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/cisco-open/go-lanai/pkg/data/types"
"github.com/cisco-open/go-lanai/pkg/opa"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"sync"
)
/****************************
Types
****************************/
// FilteredModel is a marker type that can be used in model struct as Embedded Struct.
// It's responsible for automatically applying OPA policy-based data filtering on model fields with "opa" tag.
//
// FilteredModel uses following GORM interfaces to modify PostgreSQL statements during select/update/delete,
// and apply value checks during create/update:
//
//   - schema.QueryClausesInterface
//   - schema.UpdateClausesInterface
//   - schema.DeleteClausesInterface
//   - schema.CreateClausesInterface
//
// When FilteredModel is present in data model, any model's fields tagged with "opa" will be used by OPA engine as following:
//
//   - During "create", values are included with path "input.resource.<opa_field_name>"
//   - During "update", values are included with path "input.resources.delta.<opa_field_name>"
//   - During "select/update/delete", "input.resource.<opa_field_name>" is used as "unknowns" during OPA Partial Evaluation,
//     and the result is translated to "WHERE" clause in PostgreSQL
//
// Where "<opa_field_name>" is specified by "opa" tag as `opa:"field:<opa_field_name>"`
//
// # Usage:
//
// FilteredModel is used as Embedded Struct in model struct,
//   - "opa" tag is required with resource type defined:
//     `opa:"type:<opa_res_type>"`
//   - "gorm" tag should not be applied to the embedded struct
//
// # Examples:
//
//	type Model struct {
//		ID         uuid.UUID          `gorm:"primaryKey;type:uuid;default:gen_random_uuid();"`
//		Value      string
//		TenantID   uuid.UUID          `gorm:"type:KeyID;not null" opa:"field:tenant_id"`
//		TenantPath pqx.UUIDArray      `gorm:"type:uuid[];index:,type:gin;not null" opa:"field:tenant_path"`
//		OwnerID    uuid.UUID          `gorm:"type:KeyID;not null" opa:"field:owner_id"`
//		opadata.FilteredModel `opa:"type:my_resource"`
//	}
//
// Note: OPA filtering on relationships is currently not supported
//
// # Supported Tags:
//
// OPA tag should be in format of:
//
//	`opa:"<key>:<value>,<key>:<value>,..."`
//
// Invalid format or use of unsupported tag keys will result in a schema parsing error.
//
// Supported tag keys are:
//   - "field:<opa_input_field_name>": required on any data field in model, only applicable on data fields
//   - "input:<opa_input_field_name>": "input" is an alias of "field", only applicable on data fields
//   - "type:<opa_resource_type>": required on FilteredModel. Ignored on other fields.
//     This value will be used as prefix/package of OPA policy: e.g. "<opa_resource_type>/<policy_name>"
//
// Following keys can override CRUD policies and are only applicable on FilteredModel:
//
//   - "create:<policy_name>": optional, override policy used in OPA during create.
//   - "read:<policy_name>": optional, override policy used in OPA during read.
//   - "update:<policy_name>": optional, override policy used in OPA during update.
//   - "delete:<policy_name>": optional, override policy used in OPA during delete.
//   - "package:<policy_package>": optional, override policy's package. Default is "resource.<opa_resource_type>"
//
// Note: When <policy_name> is "-", policy-based data filtering is disabled for that operation.
// The default values are "filter_<op>"
type FilteredModel struct {
	PolicyFilter policyFilter `gorm:"-"`
}
// Filter is a marker type that can be used in model struct as Struct Field.
// It's responsible for automatically applying OPA policy-based data filtering on model fields with "opa" tag.
//
// Filter uses following GORM interfaces to modify PostgreSQL statements during select/update/delete,
// and apply value checks during create/update:
//
//   - schema.QueryClausesInterface
//   - schema.UpdateClausesInterface
//   - schema.DeleteClausesInterface
//   - schema.CreateClausesInterface
//
// When Filter is present in data model, any model's fields tagged with "opa" will be used by OPA engine as following:
//
//   - During "create", values are included with path "input.resource.<opa_field_name>"
//   - During "update", values are included with path "input.resources.delta.<opa_field_name>"
//   - During "select/update/delete", "input.resource.<opa_field_name>" is used as "unknowns" during OPA Partial Evaluation,
//     and the result is translated to "WHERE" clause in PostgreSQL
//
// Where "<opa_field_name>" is specified by "opa" tag as `opa:"field:<opa_field_name>"`
//
// # Usage:
//
// Filter is used as type of Struct Field within model struct:
//   - "opa" tag is required on the field with resource type defined:
//     `opa:"type:<opa_res_type>"`
//   - `gorm:"-"` is required
//   - the field needs to be exported
//
// # Examples:
//
//	type Model struct {
//		ID        uuid.UUID           `gorm:"primaryKey;type:uuid;default:gen_random_uuid();"`
//		Value     string
//		OwnerName string
//		OwnerID   uuid.UUID           `gorm:"type:KeyID;not null" opa:"field:owner_id"`
//		Sharing   constraints.Sharing `opa:"field:sharing"`
//		OPAFilter opadata.Filter      `gorm:"-" opa:"type:my_resource"`
//	}
//
// Note: OPA filtering on relationships is currently not supported
//
// # Supported Tags:
//
// OPA tag should be in format of:
//
//	`opa:"<key>:<value>,<key>:<value>,..."`
//
// Invalid format or use of unsupported tag keys will result in a schema parsing error.
//
// Supported tag keys are:
//   - "field:<opa_input_field_name>": required on any data field in model, only applicable on data fields
//   - "input:<opa_input_field_name>": "input" is an alias of "field", only applicable on data fields
//   - "type:<opa_resource_type>": required on the Filter field. Ignored on other fields.
//     This value will be used as prefix/package of OPA policy: e.g. "<opa_resource_type>/<policy_name>"
//
// Following keys can override CRUD policies and are only applicable on the Filter field:
//
//   - "create:<policy_name>": optional, override policy used in OPA during create.
//   - "read:<policy_name>": optional, override policy used in OPA during read.
//   - "update:<policy_name>": optional, override policy used in OPA during update.
//   - "delete:<policy_name>": optional, override policy used in OPA during delete.
//   - "package:<policy_package>": optional, override policy's package. Default is "resource.<opa_resource_type>"
//
// Note: When <policy_name> is "-", policy-based data filtering is disabled for that operation.
// The default values are "filter_<op>"
type Filter struct {
	policyFilter
}
/****************************
Policy Filter
****************************/
// policyFilter is the GORM-facing implementation behind the marker types.
// FilteredModel holds it as the named field "PolicyFilter" tagged `gorm:"-"`, and Filter embeds it.
// GORM collects the clauses returned by the methods below. They add a "WHERE" clause for OPA policy filtering
// on select/update/delete, and an OPA policy check on create.
// Implemented interfaces:
//   - schema.QueryClausesInterface
//   - schema.UpdateClausesInterface
//   - schema.DeleteClausesInterface
//   - schema.CreateClausesInterface
//
// NOTE(review): earlier docs also listed schema.GormDataTypeInterface; no GormDataType method is declared next to
// this type — confirm whether it is implemented elsewhere.
type policyFilter struct{}

// QueryClauses implements schema.QueryClausesInterface (policy filtering on "select").
func (policyFilter) QueryClauses(f *schema.Field) []clause.Interface {
	modifier := newStatementModifier(f, DBOperationFlagRead)
	return []clause.Interface{modifier}
}

// UpdateClauses implements schema.UpdateClausesInterface (policy filtering on "update").
func (policyFilter) UpdateClauses(f *schema.Field) []clause.Interface {
	modifier := newStatementModifier(f, DBOperationFlagUpdate)
	return []clause.Interface{modifier}
}

// DeleteClauses implements schema.DeleteClausesInterface (policy filtering on "delete").
func (policyFilter) DeleteClauses(f *schema.Field) []clause.Interface {
	modifier := newStatementModifier(f, DBOperationFlagDelete)
	return []clause.Interface{modifier}
}

// CreateClauses implements schema.CreateClausesInterface (policy check on "create").
func (policyFilter) CreateClauses(f *schema.Field) []clause.Interface {
	modifier := newStatementModifier(f, DBOperationFlagCreate)
	return []clause.Interface{modifier}
}
/***************************
Read, Delete
***************************/
// statementModifier implements clause.Interface and gorm.StatementModifier, where gorm.StatementModifier does the
// real work: it appends an OPA-derived "WHERE" clause to read/delete statements.
// The clause.Interface methods come from the embedded types.NoopStatementModifier.
// See gorm.DeletedAt for impl. reference.
// Metadata is loaded lazily on first use (see lazyInit). The struct contains a sync.Once and must not be copied
// after first use.
type statementModifier struct {
	types.NoopStatementModifier
	Metadata
	initOnce sync.Once
	// initErr keeps the outcome of the one-time metadata loading, so that every lazyInit call reports it
	initErr              error
	Schema               *schema.Schema
	Flag                 DBOperationFlag
	OPAFilterOptionsFunc func(stmt *gorm.Statement) (opa.ResourceFilterOptions, error)
}

// newStatementModifier returns the clause handling the DB operation "flag" for the schema of f.
// Create and update use dedicated modifiers. Read, delete and any other flag use the plain statementModifier.
func newStatementModifier(f *schema.Field, flag DBOperationFlag) clause.Interface {
	switch flag {
	case DBOperationFlagCreate:
		return newCreateStatementModifier(f.Schema)
	case DBOperationFlagUpdate:
		return newUpdateStatementModifier(f.Schema)
	default:
		ret := &statementModifier{
			Schema: f.Schema,
			Flag:   flag,
		}
		ret.OPAFilterOptionsFunc = ret.opaFilterOptions
		return ret
	}
}

// lazyInit loads the Metadata of m.Schema exactly once.
// A loading failure is returned, as a data.ErrorCodeInvalidApiUsage error, on every call, not only the first.
// Otherwise later calls would proceed with zero-valued Metadata, whose empty mode makes shouldSkip skip filtering.
func (m *statementModifier) lazyInit() error {
	m.initOnce.Do(func() {
		meta, e := loadMetadata(m.Schema)
		if e != nil {
			m.initErr = data.NewDataError(data.ErrorCodeInvalidApiUsage, e)
			return
		}
		m.Metadata = *meta
	})
	return m.initErr
}
// ModifyStatement implements gorm.StatementModifier.
// Unless filtering is skipped for this operation (see shouldSkip), it asks OPA to partially evaluate the resolved
// policy, with the model's tagged fields as unknowns. The translated expressions are appended to the statement's
// "WHERE" clause; several expressions are OR-ed together.
// Errors are reported via stmt.AddError:
//   - metadata or filter-option resolution failures as data.ErrorCodeInvalidApiUsage errors;
//   - opa.ErrQueriesNotResolved as opa.ErrAccessDenied ("record not found");
//   - any other OPA failure, or a result that is not []clause.Expression, as internal errors.
func (m *statementModifier) ModifyStatement(stmt *gorm.Statement) {
	if stmt.AddError(m.lazyInit()) != nil {
		return
	}
	if shouldSkip(stmt.Context, m.Flag, m.mode) {
		return
	}
	filterOpts, e := m.OPAFilterOptionsFunc(stmt)
	if e != nil {
		_ = stmt.AddError(data.NewDataError(data.ErrorCodeInvalidApiUsage, fmt.Sprintf(`OPA filtering failed with error: %v`, e), e))
		return
	}
	rs, e := opa.FilterResource(stmt.Context, m.ResType, flagToResOp(m.Flag), filterOpts)
	if e != nil {
		switch {
		case errors.Is(e, opa.ErrQueriesNotResolved):
			_ = stmt.AddError(opa.ErrAccessDenied.WithMessage("record not found"))
		default:
			_ = stmt.AddError(data.NewInternalError(fmt.Sprintf(`OPA filtering failed with error: %v`, e), e))
		}
		return
	}
	// checked assertion: panicking inside a GORM callback would take down the caller
	exprs, ok := rs.Result.([]clause.Expression)
	if !ok {
		e = fmt.Errorf("unexpected OPA filtering result type %T", rs.Result)
		_ = stmt.AddError(data.NewInternalError(fmt.Sprintf(`OPA filtering failed with error: %v`, e), e))
		return
	}
	if len(exprs) == 0 {
		return
	}
	// special fix for db.Model(&policyTarget{}).Where(&policyTarget{f1:v1}).Or(&policyTarget{f2:v2})...
	// Ref: https://github.com/go-gorm/gorm/issues/3627
	//      https://github.com/go-gorm/gorm/commit/9b2181199d88ed6f74650d73fa9d20264dd134c0#diff-e3e9193af67f3a706b3fe042a9f121d3609721da110f6a585cdb1d1660fd5a3c
	types.FixWhereClausesForStatementModifier(stmt)
	if len(exprs) == 1 {
		stmt.AddClause(clause.Where{Exprs: exprs})
	} else {
		stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(exprs...)}})
	}
}
// opaFilterOptions builds the OPA partial-evaluation options for stmt:
//   - every tagged field becomes an unknown "<InputPrefixRoot>.<InputPrefixResource>.<input field>";
//   - the query is resolved for this operation (see resolveQuery);
//   - results are mapped by a GormPartialQueryMapper bound to stmt;
//   - extra data from FilterWithExtraData is copied in.
//
// It never returns an error.
func (m *statementModifier) opaFilterOptions(stmt *gorm.Statement) (opa.ResourceFilterOptions, error) {
	unknowns := make([]string, 0, len(m.Fields))
	for inputField := range m.Fields {
		unknowns = append(unknowns, fmt.Sprintf(`%s.%s.%s`, opa.InputPrefixRoot, opa.InputPrefixResource, inputField))
	}
	opts := func(rf *opa.ResourceFilter) {
		rf.Query = resolveQuery(stmt.Context, m.Flag, true, &m.Metadata)
		rf.Unknowns = unknowns
		rf.QueryMapper = NewGormPartialQueryMapper(&GormMapperConfig{
			Metadata:  &m.Metadata,
			Fields:    m.Fields,
			Statement: stmt,
		})
		populateExtraData(stmt.Context, rf.ExtraData)
	}
	return opts, nil
}
/***************************
Update
***************************/
// updateStatementModifier is a special statementModifier for "update" operations.
// It applies the same "WHERE" clause as statementModifier, via the promoted ModifyStatement. In addition, it
// resolves the values to be written from the single update target and passes them to OPA as the filter's Delta.
// Batch updates (more than one target) are rejected.
type updateStatementModifier struct {
	statementModifier
}

// newUpdateStatementModifier returns an updateStatementModifier for schema s. Its update-specific options
// builder is wired into OPAFilterOptionsFunc, so the promoted ModifyStatement uses it.
func newUpdateStatementModifier(s *schema.Schema) *updateStatementModifier {
	ret := &updateStatementModifier{
		statementModifier{
			Schema: s,
			Flag:   DBOperationFlagUpdate,
		},
	}
	ret.OPAFilterOptionsFunc = ret.opaFilterOptions
	return ret
}

// opaFilterOptions extends statementModifier.opaFilterOptions with the update delta.
// It returns ErrUnsupportedUsage if:
//   - the targets cannot be resolved;
//   - there is no target or more than one target;
//   - the target's values cannot be extracted.
//
// On error, the returned options are always nil.
func (m *updateStatementModifier) opaFilterOptions(stmt *gorm.Statement) (opa.ResourceFilterOptions, error) {
	opts, e := m.statementModifier.opaFilterOptions(stmt)
	if e != nil {
		return nil, e
	}
	targets, e := resolvePolicyTargets(stmt, &m.Metadata, m.Flag)
	if e != nil {
		return nil, ErrUnsupportedUsage.WithMessage("failed resolve delta in 'update' DB operation: %v", e)
	}
	switch {
	case len(targets) == 0:
		return nil, ErrUnsupportedUsage.WithMessage("unable to resolve delta in 'update' DB operation")
	case len(targets) > 1:
		return nil, ErrUnsupportedUsage.WithMessage("'update' DB operation with batch update is not supported")
	}
	delta, e := targets[0].toResourceValues()
	if e != nil {
		return nil, ErrUnsupportedUsage.WithMessage(`%v`, e)
	}
	return func(rf *opa.ResourceFilter) {
		opts(rf)
		rf.Delta = delta
	}, nil
}
/***************************
Create
***************************/
// createStatementModifier is a special statementModifier that performs an OPA policy check on resource creation.
// Note: this modifier doesn't actually modify the statement. It checks the to-be-created model/map against OPA and
// returns an error if not allowed.
type createStatementModifier struct {
	statementModifier
}

// newCreateStatementModifier returns a createStatementModifier for schema s.
// No OPAFilterOptionsFunc is needed: creation is checked with opa.AllowResource instead of partial evaluation.
func newCreateStatementModifier(s *schema.Schema) *createStatementModifier {
	return &createStatementModifier{
		statementModifier{
			Schema: s,
			Flag:   DBOperationFlagCreate,
		},
	}
}

// ModifyStatement implements gorm.StatementModifier.
// Unless filtering is skipped for "create", it checks every to-be-created target against OPA and stops at the
// first failure. Failures are reported via stmt.AddError, consistently with statementModifier.ModifyStatement.
func (m *createStatementModifier) ModifyStatement(stmt *gorm.Statement) {
	if stmt.AddError(m.lazyInit()) != nil {
		return
	}
	if shouldSkip(stmt.Context, DBOperationFlagCreate, m.mode) {
		return
	}
	models, e := resolvePolicyTargets(stmt, &m.Metadata, m.Flag)
	if stmt.AddError(e) != nil {
		return
	}
	for i := range models {
		if stmt.AddError(m.checkPolicy(stmt.Context, &models[i])) != nil {
			return
		}
	}
}

// checkPolicy asks OPA whether model may be created.
// It returns opa.ErrAccessDenied if the model's values cannot be resolved, otherwise the result of opa.AllowResource.
// Resource values may be nil when all OPA related values are zero; the query's ResourceValues is then left at its
// zero value instead of dereferencing a nil pointer.
func (m *createStatementModifier) checkPolicy(ctx context.Context, model *policyTarget) error {
	values, e := model.toResourceValues()
	if e != nil {
		return opa.ErrAccessDenied.WithMessage(`Cannot resolve values for model creation`)
	}
	return opa.AllowResource(ctx, model.meta.ResType, opa.OpCreate, func(res *opa.ResourceQuery) {
		if values != nil {
			res.ResourceValues = *values
		}
		res.Policy = resolveQuery(ctx, m.Flag, false, &m.Metadata)
		populateExtraData(ctx, res.ExtraData)
	})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opadata
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/opa/regoexpr"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/sdk"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"reflect"
"strings"
"time"
)
// typeScanner is the reflect.Type of the sql.Scanner interface.
var typeScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()

// colRefPrefix is the rego reference "input.resource", the prefix of column references in partial queries.
var colRefPrefix = ast.Ref{ast.VarTerm("input"), ast.StringTerm("resource")}

// dataTypeJSONB is the "jsonb" data type name.
const dataTypeJSONB = "jsonb"
// GormMapperConfig carries what a GormPartialQueryMapper needs:
//   - the model's Metadata;
//   - its "opa"-tagged fields, keyed by OPA input name;
//   - the GORM statement being modified.
type GormMapperConfig struct {
	Metadata  *Metadata
	Fields    map[string]*TaggedField
	Statement *gorm.Statement
}

// GormPartialQueryMapper translates OPA partial-evaluation results into GORM clause expressions for one statement.
type GormPartialQueryMapper struct {
	ctx      context.Context
	metadata *Metadata
	fields   map[string]*TaggedField
	stmt     *gorm.Statement
}

// NewGormPartialQueryMapper creates a mapper from cfg (which must be non-nil).
// The mapper is bound to context.Background() until WithContext is used.
func NewGormPartialQueryMapper(cfg *GormMapperConfig) *GormPartialQueryMapper {
	mapper := &GormPartialQueryMapper{ctx: context.Background()}
	mapper.metadata, mapper.fields, mapper.stmt = cfg.Metadata, cfg.Fields, cfg.Statement
	return mapper
}
/*****************************
ContextAware
*****************************/
// WithContext returns a shallow copy of the mapper bound to ctx; the receiver itself is left unchanged.
func (m *GormPartialQueryMapper) WithContext(ctx context.Context) sdk.PartialQueryMapper {
	return &GormPartialQueryMapper{
		ctx:      ctx,
		metadata: m.metadata,
		fields:   m.fields,
		stmt:     m.stmt,
	}
}

// Context returns the context the mapper is currently bound to.
func (m *GormPartialQueryMapper) Context() context.Context {
	bound := m.ctx
	return bound
}
/*****************************
sdk.PartialQueryMapper
*****************************/
// MapResults translates the partial queries into a clause.Expression using this mapper as the translator.
func (m *GormPartialQueryMapper) MapResults(pq *rego.PartialQueries) (interface{}, error) {
	useSelf := func(opts *regoexpr.TranslateOption[clause.Expression]) {
		opts.Translator = m
	}
	return regoexpr.TranslatePartialQueries(m.ctx, pq, useSelf)
}
// ResultToJSON encodes result as JSON and returns the encoded text as a string.
// On a marshalling error it returns (nil, err) rather than a meaningless empty string.
func (m *GormPartialQueryMapper) ResultToJSON(result interface{}) (interface{}, error) {
	data, e := json.Marshal(result)
	if e != nil {
		return nil, e
	}
	return string(data), nil
}
/****************
Translator
****************/
// Negate wraps expr in a NOT clause.
func (m *GormPartialQueryMapper) Negate(_ context.Context, expr clause.Expression) clause.Expression {
	negated := clause.Not(expr)
	return negated
}

// And combines exprs into a single AND clause.
func (m *GormPartialQueryMapper) And(_ context.Context, exprs ...clause.Expression) clause.Expression {
	conjunction := clause.And(exprs...)
	return conjunction
}

// Or combines exprs into a single OR clause.
func (m *GormPartialQueryMapper) Or(_ context.Context, exprs ...clause.Expression) clause.Expression {
	disjunction := clause.Or(exprs...)
	return disjunction
}
// Comparison translates a single rego comparison `colRef <op> val` into a GORM clause expression.
// colRef is resolved to a model field plus an optional JSONB path, and val is converted to a DB-friendly value
// for that field. Equality, inequality and ordering operators map to the matching clause types.
// regoexpr.OpHashIn (presumably rego's `in`) is rendered as the containment operator `@>`.
// Any other operator yields ErrQueryTranslation.
func (m *GormPartialQueryMapper) Comparison(ctx context.Context, op ast.Ref, colRef ast.Ref, val interface{}) (ret clause.Expression, err error) {
	field, jsonPath, e := m.ResolveField(ctx, colRef)
	if e != nil {
		return nil, e
	}
	column := m.ResolveColumnExpr(ctx, field, jsonPath...)
	value := m.ResolveValueExpr(ctx, val, field)

	switch op.Hash() {
	case regoexpr.OpHashEqual, regoexpr.OpHashEq:
		return &clause.Eq{Column: column, Value: value}, nil
	case regoexpr.OpHashNeq:
		return &clause.Neq{Column: column, Value: value}, nil
	case regoexpr.OpHashLte:
		return &clause.Lte{Column: column, Value: value}, nil
	case regoexpr.OpHashLt:
		return &clause.Lt{Column: column, Value: value}, nil
	case regoexpr.OpHashGte:
		return &clause.Gte{Column: column, Value: value}, nil
	case regoexpr.OpHashGt:
		return &clause.Gt{Column: column, Value: value}, nil
	case regoexpr.OpHashIn:
		return clause.Expr{
			SQL:  fmt.Sprintf("%s @> ?", column),
			Vars: []interface{}{value},
		}, nil
	default:
		return nil, ErrQueryTranslation.WithMessage("Unsupported Rego operator: %v", op)
	}
}
/****************
Helpers
****************/
// Quote quotes field (a column, table, name, ...) using the underlying GORM statement.
func (m *GormPartialQueryMapper) Quote(field interface{}) string {
	stmt := m.stmt
	return stmt.Quote(field)
}
// ResolveField maps an OPA unknown reference (input.resource.<a>.<b>...) to a tagged model field.
// Path segments after the prefix are dot-joined and looked up in the field map until a field matches.
// Any segments left after the match are returned as a JSONB path, which is only allowed on JSONB columns.
// It returns ErrQueryTranslation when:
//   - the prefix is missing;
//   - a segment is not a string;
//   - no field matches;
//   - a path remains on a non-JSONB field.
// NOTE: carried over from the original author: "TODO review this part".
func (m *GormPartialQueryMapper) ResolveField(_ context.Context, colRef ast.Ref) (ret *TaggedField, jsonbPath []string, err error) {
	if !colRef.HasPrefix(colRefPrefix) {
		return nil, nil, ErrQueryTranslation.WithMessage(`OPA unknowns [%v] is missing prefix "%v"`, colRef, colRefPrefix)
	}

	var (
		field *TaggedField
		key   string
		rest  []string
	)
	for _, term := range colRef[len(colRefPrefix):] {
		var segment string
		if e := ast.As(term.Value, &segment); e != nil {
			return nil, nil, ErrQueryTranslation.WithMessage(`OPA unknowns [%v] contains invalid term [%v]`, colRef, term)
		}
		if field != nil {
			// a column is already matched: the remaining segments form the JSONB path
			rest = append(rest, segment)
			continue
		}
		key = strings.TrimPrefix(key+"."+segment, ".")
		field = m.fields[key]
	}

	switch {
	case field == nil:
		return nil, nil, ErrQueryTranslation.WithMessage(`unable to resolve column with OPA unknowns [%v]`, colRef)
	case len(rest) != 0 && strings.ToLower(string(field.DataType)) != dataTypeJSONB:
		return nil, nil, ErrQueryTranslation.WithMessage(`unable to resolve column with OPA unknowns [%v]: found field "%s" but it's not JSONB`, colRef, field.Name)
	}
	return field, rest, nil
}
// ResolveColumnExpr builds the quoted column expression for field.
// When paths are given, the column is treated as JSONB and each segment is appended with the `->` operator,
// e.g. <quoted column> -> 'a' -> 'b'.
// Segments are embedded as SQL string literals, so single quotes inside them are escaped by doubling.
func (m *GormPartialQueryMapper) ResolveColumnExpr(_ context.Context, field *TaggedField, paths ...string) string {
	col := clause.Column{
		Table: field.Schema.Table,
		Name:  field.DBName,
	}
	expr := m.Quote(col)
	for _, path := range paths {
		// path segments come from OPA unknown references; escape them so a `'` cannot terminate the literal early
		expr = fmt.Sprintf(`%s -> '%s'`, expr, strings.ReplaceAll(path, `'`, `''`))
	}
	return expr
}
// ResolveValueExpr converts val, a value taken from an OPA partial query, into a value the DB driver
// understands for field. It tries, in order:
//  1. conversion to the field's Go type (directly, via sql.Scanner, or wrapped into a slice/array/pointer);
//  2. a representation based on the field's DB data type (JSONB text, time.Time);
//  3. val unchanged.
// A nil val is returned as-is, since the reflection-based conversions cannot operate on it.
func (m *GormPartialQueryMapper) ResolveValueExpr(_ context.Context, val interface{}, field *TaggedField) interface{} {
	rv := reflect.ValueOf(val)
	if !rv.IsValid() {
		// val is nil (presumably a rego `null` operand); reflect.Value.CanConvert/Interface panic on the zero Value
		return val
	}
	if v, ok := m.resolveValueByType(rv, field.FieldType); ok {
		return v.Interface()
	}
	if v, ok := m.resolveValueByDataType(rv, field.DataType); ok {
		return v.Interface()
	}
	return val
}
// resolveValueByType convert given src to DB recognizable value of given type hint. e.g. []string to pqx.UUIDArray.
// In case the type hint is potential reference or container of source value, the source value is converted and wrapped
// using type hint.
// e.g. pqx.UUIDArray is a potential container of string, the string is converted to a pqx.UUIDArray with single element
// e.g. uuid.UUID is a potential reference of string, the string is converted to a pointer to uuid.UUID
// Note: This function guarantees that the returned value is same type of given type hint
func (m *GormPartialQueryMapper) resolveValueByType(src reflect.Value, typeHint reflect.Type) (reflect.Value, bool) {
// first, try convert directly, or via sql.Scanner API
if resolved, ok := m.toType(src, typeHint); ok {
return resolved, true
}
// second, if it's slice, array or pointer, try to convert given value to its Elem()
var resolved reflect.Value
kind := typeHint.Kind()
//nolint:exhaustive // we only handle slice, array or pointer
switch kind {
case reflect.Slice, reflect.Array, reflect.Pointer:
v, ok := m.resolveValueByType(src, typeHint.Elem())
if !ok {
return resolved, false
}
resolved = v
// wrap resolved value into proper type
//nolint:exhaustive // we only handle slice, array or pointer
switch kind {
case reflect.Slice:
ret := reflect.MakeSlice(typeHint, 1, 1)
ret.Index(0).Set(resolved)
return ret, true
case reflect.Pointer:
if resolved.CanAddr() {
return resolved.Addr(), true
}
return src, false
case reflect.Array:
ret := reflect.New(typeHint).Elem()
if typeHint.Len() > 0 {
ret.Index(0).Set(resolved)
}
return ret, true
}
}
return src, false
}
// resolveValueByDataType presents src in a DB-recognizable form based on the column's data type
// (compared case-insensitively). Only a minimal set of data types is supported:
//   - jsonb: src is JSON-encoded and returned as a string;
//   - time: an integer-convertible src is read as a Unix timestamp in seconds; a string src is parsed as ISO8601.
// It returns (src, false) when no representation applies.
func (m *GormPartialQueryMapper) resolveValueByDataType(src reflect.Value, dataType schema.DataType) (reflect.Value, bool) {
	normalized := strings.ToLower(string(dataType))
	if normalized == dataTypeJSONB {
		data, e := json.Marshal(src.Interface())
		if e != nil {
			return src, false
		}
		return reflect.ValueOf(string(data)), true
	}
	if normalized != string(schema.Time) {
		return src, false
	}

	if seconds, ok := m.toType(src, reflect.TypeOf(int64(0))); ok {
		return reflect.ValueOf(time.Unix(seconds.Int(), 0)), true
	}
	if src.Kind() == reflect.String {
		if t := utils.ParseTimeISO8601(src.String()); !t.IsZero() {
			return reflect.ValueOf(t), true
		}
	}
	return src, false
}
// toType converts src to typ. It tries a direct reflect conversion first (scalars, strings, named aliases, ...),
// then the sql.Scanner interface implemented by typ or *typ.
// It returns (src, false) when no conversion applies, including when src is the zero reflect.Value.
func (m *GormPartialQueryMapper) toType(src reflect.Value, typ reflect.Type) (reflect.Value, bool) {
	if !src.IsValid() {
		// nothing to convert; CanConvert/Interface would panic on the zero Value
		return src, false
	}
	switch {
	case src.CanConvert(typ):
		return src.Convert(typ), true
	case typ.Kind() == reflect.Pointer && typ.Implements(typeScanner):
		// typ is itself a pointer (e.g. *uuid.UUID): its zero value is nil, so allocate the pointee
		// to give Scan a usable receiver instead of dereferencing nil.
		v := reflect.New(typ.Elem())
		if e := v.Interface().(sql.Scanner).Scan(src.Interface()); e == nil {
			return v, true
		}
	case typ.Implements(typeScanner):
		v := reflect.New(typ).Elem()
		if e := v.Interface().(sql.Scanner).Scan(src.Interface()); e == nil {
			return v, true
		}
	case reflect.PointerTo(typ).Implements(typeScanner):
		v := reflect.New(typ)
		if e := v.Interface().(sql.Scanner).Scan(src.Interface()); e == nil {
			return v.Elem(), true
		}
	}
	return src, false
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opadata
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/opa"
"gorm.io/gorm"
"reflect"
)
/*********************
Model Resolver
*********************/
// policyTarget collects information about the current policy target.
// The target could be a model struct instance, pointer, map of key-value pairs, etc.
type policyTarget struct {
// meta is the model metadata; its Fields drive toResourceValues.
meta *Metadata
// modelPtr is set when the target is a pointer to the statement's model struct.
modelPtr reflect.Value
// modelValue is modelPtr.Elem(), the model struct itself.
modelValue reflect.Value
// model is the raw target value.
model interface{}
// valueMap is set when the target is a generic map[string]interface{}.
valueMap map[string]interface{}
}
// toResourceValues converts the target into opa.ResourceValues. Every non-zero tagged field is placed into
// ExtraData, keyed by its Metadata.Fields key:
//   - struct targets read fields by struct index;
//   - map targets are looked up by field name first, then by DB column name.
// It returns (nil, nil) when no non-zero value is found, and ErrUnsupportedUsage when the target is neither.
func (m policyTarget) toResourceValues() (*opa.ResourceValues, error) {
	input := map[string]interface{}{}
	if m.modelValue.IsValid() {
		// create by model struct
		for key, tagged := range m.meta.Fields {
			fv := m.modelValue.FieldByIndex(tagged.StructField.Index)
			if fv.IsValid() && !fv.IsZero() {
				input[key] = fv.Interface()
			}
		}
	} else if m.valueMap != nil {
		// create by model map
		for key, tagged := range m.meta.Fields {
			val := m.valueMap[tagged.Name]
			if val == nil {
				val = m.valueMap[tagged.DBName]
			}
			if val != nil && !reflect.ValueOf(val).IsZero() {
				input[key] = val
			}
		}
	} else {
		return nil, ErrUnsupportedUsage.WithMessage(`Cannot resolve values of model`)
	}

	if len(input) == 0 {
		return nil, nil
	}
	return &opa.ResourceValues{ExtraData: input}, nil
}
// resolvePolicyTargets resolves the model values to be created/updated/read/deleted.
// Depending on the operation and GORM usage, values are extracted from Statement.Dest (update) or from
// Statement.ReflectValue (everything else). Each extracted value must be either a pointer to the statement's
// model struct or a generic map[string]interface{}; anything else is an error.
// Extraction errors are wrapped with %w.
func resolvePolicyTargets(stmt *gorm.Statement, meta *Metadata, op DBOperationFlag) ([]policyTarget, error) {
	resolved := make([]policyTarget, 0, 5)
	fn := func(v reflect.Value) error {
		model := policyTarget{
			meta:  meta,
			model: v.Interface(),
		}
		switch {
		case v.Type() == reflect.PointerTo(stmt.Schema.ModelType):
			model.modelPtr = v
			model.modelValue = v.Elem()
		case v.Type() == typeGenericMap:
			// the type check above makes the assertion safe; no conversion needed
			model.valueMap = v.Interface().(map[string]interface{})
		default:
			return fmt.Errorf("unsupported dest model [%T]", v.Interface())
		}
		resolved = append(resolved, model)
		return nil
	}

	var e error
	switch op {
	case DBOperationFlagUpdate:
		// for update, Statement.Dest should be used instead of Statement.ReflectValue.
		// See callbacks.SetupUpdateReflectValue() (update.go)
		e = walkDest(stmt, fn)
	default:
		e = walkReflectValue(stmt, fn)
	}
	if e != nil {
		return nil, fmt.Errorf("unable to extract current model: %w", e)
	}
	return resolved, nil
}
// walkDest is similar to callbacks.callMethod, but walks the statement's Dest instead of its ReflectValue.
// All pointer levels of Dest are dereferenced first. fn is then called with a pointer to each model struct,
// or with the map itself for map destinations. See walkValues for the traversal rules.
func walkDest(stmt *gorm.Statement, fn func(value reflect.Value) error) (err error) {
	dest := reflect.ValueOf(stmt.Dest)
	for dest.Kind() == reflect.Pointer {
		dest = dest.Elem()
	}
	return walkValues(dest, fn)
}
// walkReflectValue is similar to callbacks.callMethod. It walks the statement's ReflectValue as-is and calls fn
// with a pointer to each model struct, or with the map itself for map values. See walkValues for the rules.
func walkReflectValue(stmt *gorm.Statement, fn func(value reflect.Value) error) (err error) {
	target := stmt.ReflectValue
	return walkValues(target, fn)
}
// walkValues recursively walks rv and calls fn for every model it finds:
//   - slice/array: each element is dereferenced through all pointer levels and walked recursively
//     (nil pointers end up as invalid values and are skipped);
//   - struct: fn receives its address; a non-addressable struct yields gorm.ErrInvalidValue;
//   - map: fn receives the map value itself.
// Any other kind, including an invalid Value or an undereferenced pointer, is ignored.
// The first error aborts the walk.
func walkValues(rv reflect.Value, fn func(value reflect.Value) error) error {
	//nolint:exhaustive // we only deal with map, struct and slice
	switch rv.Kind() {
	case reflect.Map:
		return fn(rv)
	case reflect.Struct:
		if rv.CanAddr() {
			return fn(rv.Addr())
		}
		return gorm.ErrInvalidValue
	case reflect.Slice, reflect.Array:
		for i, n := 0, rv.Len(); i < n; i++ {
			item := rv.Index(i)
			for item.Kind() == reflect.Pointer {
				item = item.Elem()
			}
			if e := walkValues(item, fn); e != nil {
				return e
			}
		}
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
opalogging "github.com/open-policy-agent/opa/logging"
"github.com/open-policy-agent/opa/plugins"
"github.com/open-policy-agent/opa/sdk"
"io"
)
// embeddedOPA holds the process-wide embedded OPA state recorded by NewEmbeddedOPA: the SDK instance,
// its ready channel and the global input customizers.
var embeddedOPA struct {
opa *sdk.OPA
ready EmbeddedOPAReadyCH
inputCustomizers []InputCustomizer
}
// EmbeddedOPAReadyCH is the receive-only view of the channel passed to sdk.Options.Ready in NewEmbeddedOPA.
type EmbeddedOPAReadyCH <-chan struct{}
// EmbeddedOPA returns the instance recorded by the last successful NewEmbeddedOPA call, or nil if none.
func EmbeddedOPA() *sdk.OPA {
	instance := embeddedOPA.opa
	return instance
}
// EmbeddedOPAOptions is a functional option that mutates an EmbeddedOPAOption before NewEmbeddedOPA uses it.
type EmbeddedOPAOptions func(opts *EmbeddedOPAOption)
// EmbeddedOPAOption collects the settings used by NewEmbeddedOPA.
type EmbeddedOPAOption struct {
// SDKOptions are the raw sdk.Options passed to sdk.New
SDKOptions sdk.Options
// Config struct, when set, overrides SDKOptions.Config
Config *Config
// InputCustomizers are installed as global input customizers for any OPA queries
InputCustomizers []InputCustomizer
// Properties for extra configuration not included in Config
Properties *Properties
}
// WithConfig sets the structured Config; when present it overrides SDKOptions.Config.
func WithConfig(cfg *Config) EmbeddedOPAOptions {
	return func(o *EmbeddedOPAOption) {
		o.Config = cfg
	}
}

// WithRawConfig sets the raw JSON configuration read by the SDK.
func WithRawConfig(jsonReader io.Reader) EmbeddedOPAOptions {
	return func(o *EmbeddedOPAOption) {
		o.SDKOptions.Config = jsonReader
	}
}

// WithLogger sets the logger used by the SDK.
func WithLogger(logger opalogging.Logger) EmbeddedOPAOptions {
	return func(o *EmbeddedOPAOption) {
		o.SDKOptions.Logger = logger
	}
}

// WithLogLevel installs an OPA logger built on the package logger at the given level.
func WithLogLevel(level log.LoggingLevel) EmbeddedOPAOptions {
	return func(o *EmbeddedOPAOption) {
		o.SDKOptions.Logger = NewOPALogger(logger, level)
	}
}

// WithInputCustomizers replaces the global input customizers (the slice is stored as given, not copied).
func WithInputCustomizers(customizers ...InputCustomizer) EmbeddedOPAOptions {
	return func(o *EmbeddedOPAOption) {
		o.InputCustomizers = customizers
	}
}

// WithProperties stores a pointer to a copy of props captured when the option is created.
func WithProperties(props Properties) EmbeddedOPAOptions {
	return func(o *EmbeddedOPAOption) {
		o.Properties = &props
	}
}
// NewEmbeddedOPA creates a new sdk.OPA instance with the decision-log plugin registered, then applies opts.
// On success it records the instance, its ready channel and the input customizers globally, so they are
// reachable through EmbeddedOPA.
// Caller is responsible to call (*sdk.OPA).Stop to release resources.
// Option validation errors are returned as-is; SDK creation errors are wrapped with %w.
func NewEmbeddedOPA(ctx context.Context, opts ...EmbeddedOPAOptions) (*sdk.OPA, EmbeddedOPAReadyCH, error) {
	readyCh := make(chan struct{}, 1)
	opt := EmbeddedOPAOption{
		SDKOptions: sdk.Options{
			ID:    `Embedded-OPA`,
			Ready: readyCh,
			Plugins: map[string]plugins.Factory{
				pluginNameDecisionLogger: decisionLogPluginFactory{},
			},
		},
	}
	for _, fn := range opts {
		fn(&opt)
	}
	if e := validateOptions(ctx, &opt); e != nil {
		return nil, nil, e
	}
	opa, e := sdk.New(ctx, opt.SDKOptions)
	if e != nil {
		close(readyCh)
		return nil, nil, fmt.Errorf("error when create embedded OPA: %w", e)
	}
	// set global variable
	embeddedOPA.opa = opa
	embeddedOPA.ready = readyCh
	embeddedOPA.inputCustomizers = opt.InputCustomizers
	return opa, readyCh, nil
}
func validateOptions(ctx context.Context, opt *EmbeddedOPAOption) error {
// check logger
if opt.SDKOptions.Logger == nil {
opaLog := NewOPALogger(logger.WithContext(ctx), log.LevelInfo)
WithLogger(opaLog)(opt)
} else if v, ok := opt.SDKOptions.Logger.(*opaLogger); ok {
WithLogger(v.WithContext(ctx))(opt)
}
// check config
switch {
case opt.Config == nil && opt.SDKOptions.Config == nil:
return fmt.Errorf(`"Config" is missing`)
case opt.Config != nil:
reader, e := opt.Config.JSONReader(ctx)
if e != nil {
return e
}
WithRawConfig(reader)(opt)
}
return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"errors"
"fmt"
)
var (
	ErrInternal           = NewError("internal error")
	ErrAccessDenied       = NewError("Access Denied")
	ErrQueriesNotResolved = NewError(`OPA cannot resolve partial queries`)
)

// errorCode is the last code handed out by NewError; each call increments it, so every NewError gets a distinct code.
var errorCode int

// Error is a coded error. Errors derived via WithMessage keep their code, and Is compares codes only,
// so errors.Is matches any message variant of the same base error.
type Error struct {
	code int
	msg  string
}

// Error returns the message.
func (e Error) Error() string {
	return e.msg
}

// Is reports whether err (or an error in its chain) is an Error with the same code.
func (e Error) Is(err error) bool {
	var other Error
	if !errors.As(err, &other) {
		return false
	}
	return other.code == e.code
}

// WithMessage returns a copy of e with the same code and a message formatted from tmpl and args.
func (e Error) WithMessage(tmpl string, args ...interface{}) Error {
	clone := e
	clone.msg = fmt.Sprintf(tmpl, args...)
	return clone
}

// NewError creates an Error with a fresh code and a message formatted from tmpl and args.
func NewError(tmpl string, args ...interface{}) Error {
	errorCode++
	return Error{code: errorCode}.WithMessage(tmpl, args...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opainit
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
"github.com/cisco-open/go-lanai/pkg/opa"
"go.uber.org/fx"
)
// regDI holds the dependencies of RegisterHealth.
// HealthRegistrar is optional: when it is absent, no health indicator is registered.
type regDI struct {
fx.In
HealthRegistrar health.Registrar `optional:"true"`
OPAReady opa.EmbeddedOPAReadyCH
}
// RegisterHealth registers the OPA HealthIndicator, backed by the embedded OPA ready channel,
// when a health registrar is available.
func RegisterHealth(di regDI) {
	if di.HealthRegistrar != nil {
		di.HealthRegistrar.MustRegister(&HealthIndicator{ready: di.OPAReady})
	}
}
// HealthIndicator reports the embedded OPA engine as UP once its ready channel can be received from, DOWN otherwise.
type HealthIndicator struct {
	ready opa.EmbeddedOPAReadyCH
}

// Name returns the indicator name "opa".
func (i *HealthIndicator) Name() string {
	return "opa"
}

// Health checks the ready channel without blocking.
func (i *HealthIndicator) Health(_ context.Context, _ health.Options) health.Health {
	status, desc := health.StatusDown, "OPA engine is not ready"
	select {
	case <-i.ready:
		status, desc = health.StatusUp, "OPA engine is UP"
	default:
	}
	return health.NewDetailedHealth(status, desc, nil)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opainit
import (
"context"
"embed"
"fmt"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/opa"
opainput "github.com/cisco-open/go-lanai/pkg/opa/input"
"github.com/open-policy-agent/opa/sdk"
"go.uber.org/fx"
)
// logger is the logger of the OPA initialization package.
var logger = log.New("OPA.Init")
// defaultConfigFS embeds defaults-opa.yml; Module registers it as embedded default configuration.
//go:embed defaults-opa.yml
var defaultConfigFS embed.FS
// Module wires the embedded OPA into the application:
//   - registers the embedded defaults;
//   - provides opa.Properties and the embedded *sdk.OPA with its ready channel;
//   - hooks the OPA into the fx lifecycle;
//   - registers the OPA health indicator.
var Module = &bootstrap.Module{
Precedence: bootstrap.SecurityPrecedence,
Options: []fx.Option{
appconfig.FxEmbeddedDefaults(defaultConfigFS),
fx.Provide(BindProperties, ProvideEmbeddedOPA),
fx.Invoke(InitializeEmbeddedOPA, RegisterHealth),
},
}
// Use registers Module with bootstrap so the embedded OPA is included in the application.
// Call it from main().
func Use() {
	bootstrap.Register(Module)
}
// BindProperties binds opa.Properties from the application configuration under opa.PropertiesPrefix.
// It panics if binding fails.
func BindProperties(ctx *bootstrap.ApplicationContext) opa.Properties {
	props := opa.NewProperties()
	e := ctx.Config().Bind(props, opa.PropertiesPrefix)
	if e != nil {
		panic(fmt.Errorf("failed to bind OPA properties: %v", e))
	}
	return *props
}
// EmbeddedOPAOut exposes the embedded OPA instance and its ready channel to the fx container.
type EmbeddedOPAOut struct {
fx.Out
OPA *sdk.OPA
Ready opa.EmbeddedOPAReadyCH
}
// EmbeddedOPADI holds the dependencies of ProvideEmbeddedOPA; Customizers are collected from the "opa" fx group.
type EmbeddedOPADI struct {
fx.In
AppCtx *bootstrap.ApplicationContext
Properties opa.Properties
Customizers []opa.ConfigCustomizer `group:"opa"`
}
// ProvideEmbeddedOPA loads the OPA Config from the application context, properties and ConfigCustomizers.
// It then creates the embedded OPA with the configured log level and the default input customizers.
// Config loading errors are wrapped with %w so callers can inspect the cause.
func ProvideEmbeddedOPA(di EmbeddedOPADI) (EmbeddedOPAOut, error) {
	cfg, e := opa.LoadConfig(di.AppCtx, di.Properties, di.Customizers...)
	if e != nil {
		return EmbeddedOPAOut{}, fmt.Errorf("unable to load OPA Config: %w", e)
	}
	embedded, ready, e := opa.NewEmbeddedOPA(di.AppCtx,
		opa.WithConfig(cfg),
		opa.WithLogLevel(di.Properties.Logging.LogLevel),
		opa.WithInputCustomizers(opainput.DefaultInputCustomizers...),
	)
	if e != nil {
		return EmbeddedOPAOut{}, e
	}
	return EmbeddedOPAOut{
		OPA:   embedded,
		Ready: ready,
	}, nil
}
// InitializeEmbeddedOPA hooks the embedded OPA into the application lifecycle:
//   - on start, a background goroutine logs once OPA becomes ready;
//   - on stop, that goroutine is released, even if OPA never became ready, and the OPA instance is stopped.
func InitializeEmbeddedOPA(lc fx.Lifecycle, embedded *sdk.OPA, ready opa.EmbeddedOPAReadyCH) {
	// closed on stop so the readiness watcher cannot outlive the application
	stopCh := make(chan struct{})
	lc.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			go func() {
				select {
				case <-ready:
					logger.WithContext(ctx).Infof("Embedded OPA is Ready")
				case <-stopCh:
				}
			}()
			return nil
		},
		OnStop: func(ctx context.Context) error {
			close(stopCh)
			embedded.Stop(ctx)
			return nil
		},
	})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opainput
import (
"context"
securityinternal "github.com/cisco-open/go-lanai/internal/security"
"github.com/cisco-open/go-lanai/pkg/opa"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
)
// PopulateAuthenticationClause fills input.Authentication from the security context in ctx.
// For callers that are not fully authenticated, the clause is cleared to nil.
// Otherwise an existing clause is reused, or a new one is created, and then populated.
func PopulateAuthenticationClause(ctx context.Context, input *opa.Input) error {
	auth := security.Get(ctx)
	if security.IsFullyAuthenticated(auth) {
		if input.Authentication == nil {
			input.Authentication = opa.NewAuthenticationClause()
		}
		return populateAuthenticationClause(auth, input.Authentication)
	}
	input.Authentication = nil
	return nil
}
// populateAuthenticationClause copies authentication data into clause:
//   - username and permission names (in map iteration order);
//   - OAuth2 client info, for OAuth2 authentications;
//   - user, tenant, provider, accessible-tenant and role information, whenever the details implement the
//     corresponding interfaces.
// It currently always returns nil.
func populateAuthenticationClause(auth security.Authentication, clause *opa.AuthenticationClause) error {
	clause.Username = getUsernameOrEmpty(auth)

	perms := auth.Permissions()
	clause.Permissions = make([]string, 0, len(perms))
	for perm := range perms {
		clause.Permissions = append(clause.Permissions, perm)
	}

	if oauth, ok := auth.(oauth2.Authentication); ok {
		req := oauth.OAuth2Request()
		clause.Client = &opa.OAuthClientClause{
			ClientID:  req.ClientId(),
			GrantType: req.GrantType(),
			Scopes:    req.Scopes().Values(),
		}
	}

	details := auth.Details()
	if user, ok := details.(security.UserDetails); ok {
		clause.UserID = user.UserId()
	}
	if access, ok := details.(securityinternal.TenantAccessDetails); ok {
		clause.AccessibleTenants = access.EffectiveAssignedTenantIds().Values()
	}
	if tenant, ok := details.(security.TenantDetails); ok {
		clause.TenantID = tenant.TenantId()
	}
	if provider, ok := details.(security.ProviderDetails); ok {
		clause.ProviderID = provider.ProviderId()
	}
	if authDetails, ok := details.(security.AuthenticationDetails); ok {
		clause.Roles = authDetails.Roles().Values()
	}
	return nil
}
// getUsernameOrEmpty returns the username of auth, or "" if it cannot be determined.
func getUsernameOrEmpty(auth security.Authentication) string {
	if username, e := security.GetUsername(auth); e == nil {
		return username
	}
	return ""
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"context"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/open-policy-agent/opa/plugins"
opalogs "github.com/open-policy-agent/opa/plugins/logs"
"github.com/open-policy-agent/opa/rego"
)
// evtLogger is the base logger for OPA events such as decision logs (see eventLogger).
var evtLogger = log.New("OPA.Event")
const (
// pluginNameDecisionLogger is the plugin name NewEmbeddedOPA registers decisionLogPluginFactory under.
pluginNameDecisionLogger = `lanai_logger`
// kLogDecisionLog is the log key carrying the decision event.
kLogDecisionLog = `opa`
// The following keys are not referenced in this part of the file; presumably used when logging (partial) results.
kLogDecisionReason = `reason`
kLogPartialResult = `result`
kLogPartialReason = `reason`
)
/*******************
Log Context
*******************/
// kLogCtx is the private context-key type for log settings carried in a context.
type kLogCtx struct{}

// kLogCtxLevel is the context key under which logContext exposes its logging level.
var kLogCtxLevel = kLogCtx{}

// logContext wraps a parent context and additionally answers kLogCtxLevel with level.
type logContext struct {
	context.Context
	level log.LoggingLevel
}

// Value returns the level for kLogCtxLevel and defers every other key to the parent context.
func (c logContext) Value(key any) any {
	if key == kLogCtxLevel {
		return c.level
	}
	return c.Context.Value(key)
}

// logContextWithLevel derives a context from parent whose kLogCtxLevel value is level.
// resolveLogLevel picks it up in preference to its default level.
func logContextWithLevel(parent context.Context, level log.LoggingLevel) context.Context {
	ctx := logContext{Context: parent, level: level}
	return &ctx
}
/*******************
Leveled Log
*******************/
// eventLogger returns evtLogger bound to ctx, with the level taken from ctx when present, else defaultLevel.
func eventLogger(ctx context.Context, defaultLevel log.LoggingLevel) log.Logger {
	lvl := resolveLogLevel(ctx, defaultLevel)
	return evtLogger.WithContext(ctx).WithLevel(lvl)
}

// resolveLogLevel returns the level stored in ctx under kLogCtxLevel, or defaultLevel if there is none.
func resolveLogLevel(ctx context.Context, defaultLevel log.LoggingLevel) log.LoggingLevel {
	if override, ok := ctx.Value(kLogCtxLevel).(log.LoggingLevel); ok {
		return override
	}
	return defaultLevel
}
/*******************
Decision Log
*******************/
// decisionLogPluginFactory creates the decisionLogger plugin from LoggingProperties-shaped JSON config.
type decisionLogPluginFactory struct{}

// Validate decodes rawConfig into LoggingProperties.
func (f decisionLogPluginFactory) Validate(_ *plugins.Manager, rawConfig []byte) (interface{}, error) {
	props := LoggingProperties{}
	if e := json.Unmarshal(rawConfig, &props); e != nil {
		return nil, e
	}
	return props, nil
}

// New marks the plugin as OK on manager and returns a decisionLogger at the configured decision-log level.
// cfg is expected to be the LoggingProperties produced by Validate.
func (f decisionLogPluginFactory) New(manager *plugins.Manager, cfg interface{}) plugins.Plugin {
	status := &plugins.Status{
		State:   plugins.StateOK,
		Message: fmt.Sprintf("Plugin is ready [%s]", pluginNameDecisionLogger),
	}
	manager.UpdatePluginStatus(pluginNameDecisionLogger, status)
	props := cfg.(LoggingProperties)
	return &decisionLogger{level: props.DecisionLogsLevel}
}
// decisionLogger is the OPA SDK decision logger plugin. It implements "github.com/open-policy-agent/opa/plugins/logs".Logger
// and prints each decision event at its configured level.
type decisionLogger struct {
	level log.LoggingLevel
}

// Start is a no-op.
func (l *decisionLogger) Start(_ context.Context) error {
	return nil
}

// Stop does nothing.
func (l *decisionLogger) Stop(_ context.Context) {
}

// Reconfigure updates the level; cfg is expected to be LoggingProperties.
func (l *decisionLogger) Reconfigure(_ context.Context, cfg interface{}) {
	props := cfg.(LoggingProperties)
	l.level = props.DecisionLogsLevel
}

// Log prints the decision event under kLogDecisionLog; a level override in ctx takes precedence.
func (l *decisionLogger) Log(ctx context.Context, v1 opalogs.EventV1) error {
	evt := decisionEvent{event: &v1}
	eventLogger(ctx, l.level).WithKV(kLogDecisionLog, evt).Printf("Decision Log")
	return nil
}
/*******************
Events
*******************/
// decisionEvent wraps an OPA decision log event for logging; it renders and marshals as the event's JSON.
type decisionEvent struct {
	event *opalogs.EventV1
}

// String returns the event as JSON, falling back to %v formatting when it cannot be marshalled.
func (de decisionEvent) String() string {
	if encoded, e := json.Marshal(de.event); e == nil {
		return string(encoded)
	}
	return fmt.Sprintf("%v", de.event)
}

// MarshalJSON marshals the wrapped event.
func (de decisionEvent) MarshalJSON() ([]byte, error) {
	return json.Marshal(de.event)
}
// resultEvent is a loggable decision result.
type resultEvent struct {
	ID     string      `json:"decision_id"`
	Result interface{} `json:"result"`
	Deny   bool        `json:"deny"`
}

// String renders "[<decision id>]: <result>".
func (re resultEvent) String() string {
	return "[" + re.ID + "]: " + fmt.Sprint(re.Result)
}
// partialQueriesLog wraps rego.PartialQueries for logging. It is serialized to JSON as a single string holding
// the %v rendering of the queries.
type partialQueriesLog rego.PartialQueries

// MarshalJSON encodes the textual form of the partial queries as a properly escaped JSON string.
// The rendered queries can contain double quotes (quoted string terms), backslashes or newlines.
// Wrapping the text in quotes by hand would therefore produce invalid JSON, so it goes through json.Marshal.
func (pq partialQueriesLog) MarshalJSON() ([]byte, error) {
	return json.Marshal(fmt.Sprintf("%v", rego.PartialQueries(pq)))
}
// partialResultEvent is the loggable outcome of a partial evaluation: the decision ID plus either an error or
// the resulting partial queries.
type partialResultEvent struct {
	ID  string             `json:"decision_id"`
	Err error              `json:"error,omitempty"`
	AST *partialQueriesLog `json:"queries,omitempty"`
}

// String returns the error message when Err is set, otherwise "[<decision id>]: <queries>".
func (pre partialResultEvent) String() string {
	if pre.Err != nil {
		return pre.Err.Error()
	}
	return fmt.Sprintf("[%s]: %v", pre.ID, pre.AST)
}

// MarshalJSON encodes the event with Err rendered as its message.
// Without it, encoding/json serializes typical error values (errors.New, fmt.Errorf, ...) as `{}`,
// because they have no exported fields, and the message would be lost.
func (pre partialResultEvent) MarshalJSON() ([]byte, error) {
	var errMsg string
	if pre.Err != nil {
		errMsg = pre.Err.Error()
	}
	return json.Marshal(struct {
		ID  string             `json:"decision_id"`
		Err string             `json:"error,omitempty"`
		AST *partialQueriesLog `json:"queries,omitempty"`
	}{
		ID:  pre.ID,
		Err: errMsg,
		AST: pre.AST,
	})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"context"
"github.com/cisco-open/go-lanai/pkg/log"
opalogging "github.com/open-policy-agent/opa/logging"
)
// logger is the package logger for OPA; WithLogLevel, validateOptions and opaLogger.WithContext build on it.
var logger = log.New("OPA")
var (
// logLevelMapper maps OPA logging levels to lanai logging levels; opaLogger.SetLevel ignores unmapped levels.
logLevelMapper = map[opalogging.Level]log.LoggingLevel{
opalogging.Debug: log.LevelDebug,
opalogging.Info: log.LevelInfo,
opalogging.Warn: log.LevelWarn,
opalogging.Error: log.LevelError,
}
)
/*******************
OPA logger
*******************/
// opaLogger adapts a lanai log.Logger to OPA's logging.Logger interface.
type opaLogger struct {
	logger log.Logger
	level  opalogging.Level
}

// NewOPALogger wraps logger (re-leveled to lvl) as an OPA logger.
// The reported OPA level is Debug/Warn/Error for the matching lanai levels, and Info for anything else.
func NewOPALogger(logger log.Logger, lvl log.LoggingLevel) opalogging.Logger {
	level := opalogging.Info
	//nolint:exhaustive // every other lanai level maps to opalogging.Info
	switch lvl {
	case log.LevelDebug:
		level = opalogging.Debug
	case log.LevelWarn:
		level = opalogging.Warn
	case log.LevelError:
		level = opalogging.Error
	}
	return &opaLogger{
		logger: logger.WithLevel(lvl),
		level:  level,
	}
}
// WithContext returns a new opaLogger bound to ctx that keeps the current OPA level.
// NOTE(review): the new logger is derived from the package-level `logger`, not from l.logger.
// Anything previously applied to l.logger (key-values from WithFields, its WithLevel setting, or a custom logger
// passed to NewOPALogger) is not carried over. Confirm this is intended.
func (l *opaLogger) WithContext(ctx context.Context) *opaLogger {
	bound := logger.WithContext(ctx)
	return &opaLogger{logger: bound, level: l.level}
}
// Debug logs a formatted message at debug level through the wrapped logger.
func (l *opaLogger) Debug(format string, args ...interface{}) {
	l.logger.Debugf(format, args...)
}

// Info logs a formatted message at info level through the wrapped logger.
func (l *opaLogger) Info(format string, args ...interface{}) {
	l.logger.Infof(format, args...)
}

// Warn logs a formatted message at warn level through the wrapped logger.
func (l *opaLogger) Warn(format string, args ...interface{}) {
	l.logger.Warnf(format, args...)
}

// Error logs a formatted message at error level through the wrapped logger.
func (l *opaLogger) Error(format string, args ...interface{}) {
	l.logger.Errorf(format, args...)
}
// WithFields returns a new opaLogger whose wrapped logger carries fields as key-value pairs,
// keeping the current OPA level. Pairs are appended in Go map iteration order, which is unspecified.
func (l *opaLogger) WithFields(fields map[string]interface{}) opalogging.Logger {
	// exactly two slots (key, value) per field: no regrowth for large maps, no waste for small ones
	kvs := make([]interface{}, 0, 2*len(fields))
	for k, v := range fields {
		kvs = append(kvs, k, v)
	}
	return &opaLogger{
		logger: l.logger.WithKV(kvs...),
		level:  l.level,
	}
}
// GetLevel returns the current OPA logging level.
func (l *opaLogger) GetLevel() opalogging.Level {
	return l.level
}

// SetLevel switches to lvl when it has a lanai equivalent in logLevelMapper; unknown levels are ignored.
func (l *opaLogger) SetLevel(lvl opalogging.Level) {
	if mapped, ok := logLevelMapper[lvl]; ok {
		l.logger = l.logger.WithLevel(mapped)
		l.level = lvl
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// PropertiesPrefix is the configuration prefix that Properties is bound from.
const PropertiesPrefix = "security.opa"
// Properties holds the OPA settings bound from configuration under PropertiesPrefix.
type Properties struct {
// Server is the bundle server (name, URL, polling settings).
Server BundleServerProperties `json:"server"`
// Bundles are the bundle sources, presumably keyed by bundle name.
Bundles map[string]BundleSourceProperties `json:"bundles"`
// Logging configures the OPA engine and decision-log levels.
Logging LoggingProperties `json:"logging"`
}
// BundleServerProperties describes the server bundles are downloaded from.
type BundleServerProperties struct {
Name string `json:"name"`
URL string `json:"url"`
PollingProperties
}
// BundleSourceProperties describes a single bundle's path plus its polling settings.
type BundleSourceProperties struct {
Path string `json:"path"`
PollingProperties
}
// LoggingProperties configures OPA logging.
type LoggingProperties struct {
// LogLevel is the level of the OPA engine logger.
LogLevel log.LoggingLevel `json:"level"`
// DecisionLogsLevel is the level used when printing decision logs.
DecisionLogsLevel log.LoggingLevel `json:"decision-logs-level"`
}
// PollingProperties configures bundle polling; nil values are omitted from JSON.
type PollingProperties struct {
PollingMinDelay *utils.Duration `json:"polling-min-delay,omitempty"` // min amount of time to wait between successful poll attempts
PollingMaxDelay *utils.Duration `json:"polling-max-delay,omitempty"` // max amount of time to wait between poll attempts
LongPollingTimeout *utils.Duration `json:"long-polling-timeout,omitempty"` // max amount of time the server should wait before issuing a timeout if there's no update available
}
// NewProperties returns a pointer to zero-valued Properties, ready to be bound from configuration.
func NewProperties() *Properties {
	return new(Properties)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package regoexpr
import (
"context"
"fmt"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"strings"
)
// NoopPartialQueryMapper maps OPA partial queries into plain strings, rendering them with
// NoopQueryTranslator.
type NoopPartialQueryMapper struct{}

// MapResults translates pq into a []string (boxed as interface{}), one entry per partial query.
func (NoopPartialQueryMapper) MapResults(pq *rego.PartialQueries) (interface{}, error) {
	useNoopTranslator := func(opts *TranslateOption[string]) {
		opts.Translator = NoopQueryTranslator{}
	}
	return TranslatePartialQueries(context.Background(), pq, useNoopTranslator)
}

// ResultToJSON returns result unchanged; string results need no conversion.
func (NoopPartialQueryMapper) ResultToJSON(result interface{}) (interface{}, error) {
	return result, nil
}
// NoopQueryTranslator renders translated expressions as C-like boolean strings
// ("!", "&&", "||"). It performs no escaping or quoting of its inputs.
type NoopQueryTranslator struct{}

// Negate prefixes expr with "!".
func (t NoopQueryTranslator) Negate(_ context.Context, expr string) string {
	return fmt.Sprintf("!%s", expr)
}

// And joins exprs with " && "; no operands yield "" and a single operand is returned as-is.
func (t NoopQueryTranslator) And(_ context.Context, exprs ...string) string {
	const sep = " && "
	return strings.Join(exprs, sep)
}

// Or joins exprs with " || "; no operands yield "" and a single operand is returned as-is.
func (t NoopQueryTranslator) Or(_ context.Context, exprs ...string) string {
	const sep = " || "
	return strings.Join(exprs, sep)
}
// Comparison renders the comparison as "<column> <operator> <value>", formatting each part with %v.
// It never returns an error.
func (t NoopQueryTranslator) Comparison(_ context.Context, op ast.Ref, colRef ast.Ref, val interface{}) (string, error) {
	rendered := fmt.Sprintf("%v %v %v", colRef, op, val)
	return rendered, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package regoexpr
import (
"context"
"encoding/json"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/opa"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"reflect"
)
// logger is the package logger, named "OPA.AST".
var logger = log.New("OPA.AST")
// TranslateOptions is a functional option that configures a TranslateOption.
type TranslateOptions[EXPR any] func(opts *TranslateOption[EXPR])
// TranslateOption holds the settings used by TranslatePartialQueries and the Translate* helpers.
type TranslateOption[EXPR any] struct {
// Translator converts Rego operations into EXPR values. It is required: TranslatePartialQueries
// returns a ParsingError when it is nil.
Translator QueryTranslator[EXPR]
}
// TranslatePartialQueries translates OPA partial queries into another expression language, e.g. a Postgres
// expression.
//
// Each surviving (normalized) query becomes one EXPR: its expressions are translated individually and
// combined with QueryTranslator.And. Partial queries are disjunctive, so callers are expected to OR the
// returned expressions.
//
// Note:
//  1. When PartialQueries.Queries is empty (or every query is dropped during normalization), access is DENIED
//     regardless of any unknown values, and opa.ErrQueriesNotResolved is returned.
//  2. When PartialQueries.Queries contains a nil body, access is GRANTED regardless of any unknown values,
//     and an empty, non-nil slice is returned.
//
// A ParsingError is returned when no translator is configured, when pq is nil, or when an expression cannot
// be translated. Translated expressions that are zero values of EXPR (e.g. "" for strings, nil for
// interface or pointer types) are left out of the AND.
func TranslatePartialQueries[EXPR any](ctx context.Context, pq *rego.PartialQueries, opts ...TranslateOptions[EXPR]) ([]EXPR, error) {
	logger.WithContext(ctx).Debugf("Queries: %v", pq)
	opt := TranslateOption[EXPR]{}
	for _, fn := range opts {
		fn(&opt)
	}
	if opt.Translator == nil {
		return nil, ParsingError.WithMessage("query translator is nil")
	}
	if pq == nil {
		return nil, ParsingError.WithMessage("partial queries is nil")
	}

	// normalize
	queries, changed := NormalizeQueries(ctx, pq.Queries)
	if changed {
		logger.WithContext(ctx).Debugf("Normalized Queries: %v", queries)
	}

	// If queries is nil, it means any unknowns can satisfy.
	// However, if queries is empty, it means no unknowns would satisfy
	switch {
	case queries == nil:
		return []EXPR{}, nil
	case len(queries) == 0:
		return nil, opa.ErrQueriesNotResolved
	}

	exprs := make([]EXPR, 0, len(queries))
	for _, body := range queries {
		logger.WithContext(ctx).Debugf("Parsing Query: %v", body)
		ands := make([]EXPR, 0, len(body))
		for _, expr := range body {
			qExpr, e := TranslateExpression(ctx, expr, &opt)
			if e != nil {
				logger.WithContext(ctx).Debugf("%v", e)
				return nil, e
			}
			// When EXPR is an interface type, a nil result becomes an invalid reflect.Value, and
			// IsZero panics on invalid Values. Check validity first; a nil result is simply skipped.
			if rv := reflect.ValueOf(qExpr); rv.IsValid() && !rv.IsZero() {
				ands = append(ands, qExpr)
			}
		}
		exprs = append(exprs, opt.Translator.And(ctx, ands...))
	}
	return exprs, nil
}
// NormalizeQueries removes duplicate queries (compared by hash) and normalizes each remaining query with
// NormalizeExpressions. Queries whose expressions all get dropped (e.g. contradictory ones) are removed.
//
// It returns a nil slice (and changed == true) as soon as a nil query is found: queries are OR-ed, so a
// nil query means the whole set is unconditionally true. changed reports whether the result differs from
// the input.
func NormalizeQueries(ctx context.Context, queries []ast.Body) (ret []ast.Body, changed bool) {
	seen := make(map[int]struct{}, len(queries))
	normalized := make([]ast.Body, 0, len(queries))
	for _, body := range queries {
		if body == nil {
			// OPA concluded the policy query without needing any unknowns.
			return nil, true
		}
		h := body.Hash()
		if _, dup := seen[h]; dup {
			changed = true
			continue
		}
		seen[h] = struct{}{}

		exprs, exprsChanged := NormalizeExpressions(ctx, body)
		if len(exprs) == 0 {
			changed = true
			continue
		}
		normalized = append(normalized, exprs)
		changed = changed || exprsChanged
	}
	return normalized, changed
}
// NormalizeExpressions removes duplicate expressions from a query, comparing them by calculateHash
// (which ignores each expression's Index). The first occurrence wins and order is preserved.
//
// If the query contains contradictory expressions (see HasControversialExpressions), it can never be
// satisfied and an empty body is returned with changed == true.
func NormalizeExpressions(ctx context.Context, body ast.Body) (exprs ast.Body, changed bool) {
	kept := make(ast.Body, 0, len(body))
	if HasControversialExpressions(ctx, body) {
		logger.WithContext(ctx).Debugf("Controversial Query: %v", body)
		return kept, true
	}

	seen := map[int]struct{}{}
	for _, expr := range body {
		h := calculateHash(expr)
		if _, dup := seen[h]; dup {
			changed = true
			continue
		}
		seen[h] = struct{}{}
		kept = append(kept, expr)
	}
	return kept, changed
}
// HasControversialExpressions reports whether body contains expressions that contradict each other on the
// same ground reference, which means the conjunction can never be satisfied. Examples:
//   - "value1 = input.resource.field AND value2 = input.resource.field"
//   - "value1 = input.resource.field AND value1 != input.resource.field"
//
// Only three-term equality/inequality operations (OpEqual, OpEq, OpNeq, possibly negated) whose reference
// is ground are considered; every other expression is ignored. Values are compared by hash.
func HasControversialExpressions(_ context.Context, body []*ast.Expr) (ret bool) {
	equals := map[int]int{}
	notEquals := map[int]utils.Set{}
	for _, expr := range body {
		ref, val, op, ok := resolveThreeTermsOp(expr)
		if !ok || !ref.IsGround() {
			continue
		}
		// only handle equalities
		negate := false
		switch {
		case OpEqual.Equal(op) || OpEq.Equal(op):
			negate = expr.Negated
		case OpNeq.Equal(op):
			negate = !expr.Negated
		default:
			continue
		}
		// compare values by hashes
		rHash := ref.Hash()
		vHash := val.Hash()
		if negate {
			if _, ok := notEquals[rHash]; !ok {
				notEquals[rHash] = utils.NewSet()
			}
			notEquals[rHash].Add(vHash)
		} else {
			// var == v1 AND var == v2
			if v, ok := equals[rHash]; ok && v != vHash {
				return true
			}
			equals[rHash] = vHash
		}
		// var == v1 AND var != v1
		// Only meaningful once an equality was recorded for this ref. A missing map entry reads as 0, and
		// values such as null, false or 0 can hash to 0. Without this check a lone "var != false" would be
		// mistaken for a contradiction.
		if v, ok := equals[rHash]; ok && notEquals[rHash].Has(v) {
			return true
		}
	}
	return false
}
// TranslateExpression translates a single Rego expression using opt.Translator.
// Only operator expressions (calls or infix operations) are supported. Any other expression yields a
// ParsingError.
func TranslateExpression[EXPR any](ctx context.Context, astExpr *ast.Expr, opt *TranslateOption[EXPR]) (ret EXPR, err error) {
	if astExpr.OperatorTerm() == nil {
		return ret, ParsingError.WithMessage("unsupported Rego expression: %v", astExpr)
	}
	return TranslateOperationExpr(ctx, astExpr, opt)
}
// TranslateOperationExpr translates an operator expression. Only binary operations (exactly two operands,
// i.e. "op(a, b)" or "a op b") are supported. If the expression is negated, the translated result is
// wrapped with Translator.Negate.
func TranslateOperationExpr[EXPR any](ctx context.Context, astExpr *ast.Expr, opt *TranslateOption[EXPR]) (ret EXPR, err error) {
	if len(astExpr.Operands()) != 2 {
		return ret, ParsingError.WithMessage("unsupported Rego operation: %v", astExpr)
	}
	if ret, err = TranslateThreeTermsOp(ctx, astExpr, opt); err != nil {
		return ret, err
	}
	if astExpr.Negated {
		ret = opt.Translator.Negate(ctx, ret)
	}
	return ret, nil
}
// TranslateThreeTermsOp translates a binary operation of the form "op(Ref, Value)" or "Ref op Value" into a
// Translator.Comparison call.
//
// The value operand is converted with ast.ValueToInterface, and any reference inside it is rejected.
// json.Number values are converted to float64. Internal operators (OpInternal prefix) are rejected.
// When the reference is not ground, the comparison is rewritten as OpIn against the reference's ground
// prefix.
//
// NOTE(review): resolveThreeTermsOp does not record which side the Ref was on. For asymmetric operators
// (e.g. "5 < input.x") the operands may therefore reach the translator in swapped order. Confirm whether
// callers or translators account for this.
func TranslateThreeTermsOp[EXPR any](ctx context.Context, astExpr *ast.Expr, opt *TranslateOption[EXPR]) (ret EXPR, err error) {
	ref, val, op, ok := resolveThreeTermsOp(astExpr)
	if !ok {
		return ret, ParsingError.WithMessage(`invalid Rego operation format: expected "op(Ref, Value)", but got %v(%v)`, astExpr.OperatorTerm(), astExpr.Operands())
	}
	if op.HasPrefix(OpInternal) {
		return ret, ParsingError.WithMessage(`unsupported Rego operator [%v]`, op)
	}

	// resolve value
	value, e := ast.ValueToInterface(val, illegalResolver{})
	if e != nil {
		return ret, ParsingError.WithMessage(`unable to resolve Rego value [%v]: %v`, val, e)
	}
	if num, isNumber := value.(json.Number); isNumber {
		f, numErr := num.Float64()
		if numErr != nil {
			return ret, ParsingError.WithMessage(`unable to resolve Rego value [%v] as number: %v`, val, numErr)
		}
		value = f
	}

	// resolve operator and column
	if !ref.IsGround() {
		return opt.Translator.Comparison(ctx, OpIn, ref.GroundPrefix(), value)
	}
	return opt.Translator.Comparison(ctx, op, ref, value)
}
/**********************
Helpers
**********************/
// calculateHash returns the hash of astExpr with its Index (its position within the query body) zeroed.
// Identical expressions at different positions therefore hash equally. Only Index is neutralized: the
// hash still covers the whole expression, including any non-ground terms. astExpr itself is not modified.
func calculateHash(astExpr *ast.Expr) int {
	positionless := astExpr.Copy()
	positionless.Index = 0
	return positionless.Hash()
}
// resolveThreeTermsOp splits a binary operation "op(a, b)" or "a op b" into its operator, its Ref operand
// and its non-Ref operand. ok is false unless the expression has an operator, exactly two operands, and
// one Ref operand next to one non-Ref value. The original operand order is not reported.
func resolveThreeTermsOp(astExpr *ast.Expr) (ref ast.Ref, val ast.Value, op ast.Ref, ok bool) {
	operator := astExpr.Operator()
	operands := astExpr.Operands()
	if operator == nil || len(operands) != 2 {
		return nil, nil, nil, false
	}

	var foundRef ast.Ref
	var foundVal ast.Value
	for _, term := range operands {
		if asRef, isRef := term.Value.(ast.Ref); isRef {
			foundRef = asRef
		} else {
			foundVal = term.Value
		}
	}
	if foundRef == nil || foundVal == nil {
		return nil, nil, nil, false
	}
	return foundRef, foundVal, operator, true
}
// illegalResolver is passed to ast.ValueToInterface so that values still containing references are
// rejected: every attempt to resolve a Ref yields a ParsingError.
type illegalResolver struct{}

// Resolve always fails.
func (r illegalResolver) Resolve(_ ast.Ref) (interface{}, error) {
	return nil, ParsingError.WithMessage("resolving Ref is not supported")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opatest
import (
"context"
"embed"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/opa"
opainit "github.com/cisco-open/go-lanai/pkg/opa/init"
opatestserver "github.com/cisco-open/go-lanai/pkg/opa/test/server"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"github.com/open-policy-agent/opa/plugins/bundle"
oparest "github.com/open-policy-agent/opa/plugins/rest"
sdktest "github.com/open-policy-agent/opa/sdk/test"
"go.uber.org/fx"
"io/fs"
)
// DefaultBundleFS holds the embedded default test bundle (the "bundle" directory). BundleServerProvider
// uses it when no bundle FS is given.
//go:embed bundle/**
var DefaultBundleFS embed.FS
// DefaultConfigFS holds the embedded default OPA test configuration (test-defaults-opa.yml). WithBundles
// adds it to the test application.
//go:embed test-defaults-opa.yml
var DefaultConfigFS embed.FS
const (
// TestBundleName is the name of the single bundle served by the test bundle server.
TestBundleName = `test-bundle`
// TestBundlePathPrefix is the URL path prefix of bundles on the test bundle server.
TestBundlePathPrefix = `/bundles/`
// BundleServiceKey is the OPA service name registered for the test bundle server.
BundleServiceKey = "test-bundle-service"
)
// WithBundles returns test.Options that enable the OPA module in the test application and start an
// in-process bundle server for the given bundle FSs.
// All FSs are merged into a single bundle (TestBundleName), and the OPA engine is configured to load it.
// Application startup waits until OPA reports ready (see WaitForOPAReady).
// When no bundle FS is given, DefaultBundleFS is used.
func WithBundles(bundleFSs ...fs.FS) test.Options {
	modulesOpt := apptest.WithModules(opainit.Module)
	configOpt := apptest.WithConfigFS(DefaultConfigFS)
	fxOpt := apptest.WithFxOptions(
		fx.Provide(BundleServerProvider(bundleFSs...)),
		fx.Invoke(opatestserver.InitializeBundleServer),
		fx.Invoke(WaitForOPAReady),
	)
	return test.WithOptions(modulesOpt, configOpt, fxOpt)
}
// BundleServerDI holds the fx dependencies of the constructor returned by BundleServerProvider.
type BundleServerDI struct {
fx.In
AppCtx *bootstrap.ApplicationContext
}
// BundleServerOut is the output of the constructor returned by BundleServerProvider. It contains the
// running test bundle server and an opa.ConfigCustomizer contributed to the "opa" fx group.
type BundleServerOut struct {
fx.Out
Server *sdktest.Server
Customizer opa.ConfigCustomizer `group:"opa"`
}
// BundleServerProvider returns an fx constructor that starts a test bundle server. The server serves
// bundleFSs (DefaultBundleFS when none are given) as a single bundle named TestBundleName. The
// constructor also contributes an opa.ConfigCustomizer that points the OPA engine at that server.
func BundleServerProvider(bundleFSs ...fs.FS) func(BundleServerDI) (BundleServerOut, error) {
	sources := bundleFSs
	if len(sources) == 0 {
		sources = []fs.FS{DefaultBundleFS}
	}
	return func(di BundleServerDI) (BundleServerOut, error) {
		server, err := opatestserver.NewBundleServer(
			di.AppCtx,
			opatestserver.WithBundleSources(sources...),
			opatestserver.WithBundleName(TestBundleName),
		)
		if err != nil {
			return BundleServerOut{}, err
		}
		out := BundleServerOut{
			Server:     server,
			Customizer: newConfigCustomizer(server, TestBundleName),
		}
		return out, nil
	}
}
// WaitForOPAReady registers an fx OnStart hook that blocks application startup until one of two things
// happens: the embedded OPA engine signals readiness on ready, or the start context is done. In the
// second case the returned error wraps ctx.Err() (typically context.DeadlineExceeded when the start
// timeout elapses), so callers can inspect it with errors.Is.
func WaitForOPAReady(lc fx.Lifecycle, ready opa.EmbeddedOPAReadyCH) {
	lc.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			select {
			case <-ready:
				return nil
			case <-ctx.Done():
				return fmt.Errorf("OPA Engine cannot be initialized before timeout: %w", ctx.Err())
			}
		},
	})
}
// configCustomizer is an opa.ConfigCustomizer that points the OPA engine at a test bundle server.
// It registers a single REST service (BundleServiceKey) for the server URL, plus a single bundle
// source named BundleName that is fetched from TestBundlePathPrefix + BundleName.
type configCustomizer struct {
	Server     *sdktest.Server
	BundleName string
}

func newConfigCustomizer(server *sdktest.Server, bundleName string) *configCustomizer {
	return &configCustomizer{Server: server, BundleName: bundleName}
}

// Customize replaces (rather than merges into) cfg.Services and cfg.Bundles. TLS verification is
// disabled for the service (AllowInsecureTLS), presumably because it is a local test server.
func (c configCustomizer) Customize(_ context.Context, cfg *opa.Config) {
	service := &oparest.Config{
		Name:             BundleServiceKey,
		URL:              c.Server.URL(),
		AllowInsecureTLS: true,
	}
	source := &bundle.Source{
		Service:  BundleServiceKey,
		Resource: TestBundlePathPrefix + c.BundleName,
	}
	cfg.Services = map[string]*oparest.Config{BundleServiceKey: service}
	cfg.Bundles = map[string]*bundle.Source{c.BundleName: source}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opatestserver
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
sdktest "github.com/open-policy-agent/opa/sdk/test"
"go.uber.org/fx"
"io/fs"
"path/filepath"
"strings"
)
// logger is the package logger, named "OPA.Test".
var logger = log.New("OPA.Test")
// BundleServerOptions configures a BundleServerConfig (see NewBundleServer).
type BundleServerOptions func(cfg *BundleServerConfig)
// BundleServerConfig holds the settings used by NewBundleServer.
type BundleServerConfig struct {
// BundleName determines the served bundle path "/bundles/<BundleName>". NewBundleServer defaults it to "test".
BundleName string
// BundleSources are merged into the single served bundle.
BundleSources []fs.FS
}
// WithBundleSources returns a BundleServerOptions that appends fsys to the bundle sources (file systems
// holding bundle files). It may be applied several times to accumulate sources.
func WithBundleSources(fsys ...fs.FS) BundleServerOptions {
	appendSources := func(cfg *BundleServerConfig) {
		cfg.BundleSources = append(cfg.BundleSources, fsys...)
	}
	return appendSources
}
// WithBundleName returns a BundleServerOptions that sets the bundle name. The name determines the
// served path "/bundles/<name>".
func WithBundleName(name string) BundleServerOptions {
	setName := func(cfg *BundleServerConfig) {
		cfg.BundleName = name
	}
	return setName
}
// NewBundleServer starts an OPA SDK test server that serves a single mock bundle at
// "/bundles/<BundleName>" (BundleName defaults to "test"), built from all configured bundle sources.
//
// Sources that fail to load are logged and skipped. When sources contain files with the same bundle
// name, the later source wins. An error is returned only when no file could be loaded at all or when the
// server fails to start.
func NewBundleServer(ctx context.Context, opts ...BundleServerOptions) (*sdktest.Server, error) {
	cfg := BundleServerConfig{
		BundleName:    "test",
		BundleSources: []fs.FS{},
	}
	for _, fn := range opts {
		fn(&cfg)
	}

	policies := map[string]string{}
	for i, fsys := range cfg.BundleSources {
		if e := loadBundleFiles(fsys, policies); e != nil {
			// Sources are anonymous fs.FS values, so identify them by their position.
			logger.WithContext(ctx).Warnf("unable to load bundle [%d]: %v", i, e)
		}
	}
	if len(policies) == 0 {
		return nil, fmt.Errorf("failed to start OPA bundle server, unable to load any bundle")
	}

	// The channel is handed to sdktest.Ready and closed when this function returns, i.e. after the server
	// has been created. NOTE(review): this presumably gates the server's readiness — confirm against sdktest.Ready.
	ready := make(chan struct{}, 1)
	defer close(ready)
	server, e := sdktest.NewServer(sdktest.MockBundle("/bundles/"+cfg.BundleName, policies), sdktest.Ready(ready))
	if e != nil {
		return nil, fmt.Errorf("failed to start OPA bundle server: %w", e)
	}
	logger.WithContext(ctx).Infof("OPA Bundles served at %q", server.URL())
	return server, nil
}
// InitializeBundleServer ties the test bundle server's lifetime to the fx application: the server is
// stopped in an OnStop hook.
func InitializeBundleServer(lc fx.Lifecycle, server *sdktest.Server) {
	stopServer := func(_ context.Context) error {
		server.Stop()
		return nil
	}
	lc.Append(fx.Hook{OnStop: stopServer})
}
// loadBundleFiles reads every regular file in fsys and stores its content in dest.
//
// Keys are file names relative to the directory containing a ".manifest" file when one is found, and
// relative to the FS root otherwise. JSON files are instead keyed by their full path with "/" replaced
// by "_", because the dummy server does not implement nested data documents.
//
// Errors from unreadable directories are ignored so that the rest of the tree is still walked. An error
// is returned in three cases: the FS root itself cannot be accessed, a file cannot be read, or no file
// is found at all.
func loadBundleFiles(fsys fs.FS, dest map[string]string) error {
	// find and read all files
	files := map[string][]byte{}
	rootPath := "."
	e := fs.WalkDir(fsys, rootPath, func(path string, d fs.DirEntry, err error) error {
		if d == nil {
			// Per fs.WalkDir, d is nil only when the root itself cannot be stat'ed. There is no tree
			// to walk, so report the cause instead of calling a method on a nil DirEntry.
			return err
		}
		// Any other error (e.g. an unreadable directory) is ignored to let the walk cover the rest of the tree.
		if d.IsDir() {
			return nil
		}
		data, e := fs.ReadFile(fsys, path)
		if e != nil {
			return e
		}
		if d.Name() == ".manifest" {
			rootPath = filepath.Dir(path)
		}
		files[path] = data
		return nil
	})
	if e != nil {
		return e
	} else if len(files) == 0 {
		return fmt.Errorf("no files was found in bundle FS")
	}

	// prepare bundle content
	for path, data := range files {
		name, e := filepath.Rel(rootPath, path)
		if e != nil {
			name = path
		}
		if strings.HasSuffix(name, ".json") {
			// nested data documents are not implemented in the dummy server
			name = strings.ReplaceAll(path, "/", "_")
		}
		dest[name] = string(data)
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/test/sectest"
)
/*************************
Mocked Security
*************************/
// Fixed identifiers shared by the mocked security contexts of this package.
const (
	// OwnerUserId is the user ID used by MemberOwnerOptions.
	OwnerUserId = "20523d89-d5e9-40d0-afe3-3a74c298b55a"
	// AnotherUserId is the user ID used by the admin, non-owner and non-member mocks.
	AnotherUserId = "e7498b90-cec3-41fd-ac20-acd41769fb88"

	// RootTenantId is a tenant ID for tests that need a root tenant.
	RootTenantId = "7b3934fc-edc4-4a1c-9249-3dc7055eb124"
	// TenantId is the current tenant of the "member" mocks.
	TenantId = "8eebb711-7d24-48fb-94da-361c573d7c20"
	// AnotherTenantId is the current tenant of the "non-member" mock.
	AnotherTenantId = "b11ef279-1309-4c43-8355-99c9d494097b"

	// ProviderId is the provider ID shared by all mocks.
	ProviderId = "fe3ad89c-449f-42f2-b4f8-b10ab7bc0266"
)
// MemberAdminOptions mocks an authenticated "admin" user (AnotherUserId) whose current tenant is
// TenantId. The user has access to TenantId and AnotherTenantId, the IS_API_ADMIN, VIEW and MANAGE
// permissions, and the SUPERUSER role.
func MemberAdminOptions() sectest.SecurityContextOptions {
	customize := func(d *sectest.SecurityDetailsMock) {
		d.Username = "admin"
		d.UserId = AnotherUserId
		d.ProviderId = ProviderId
		d.TenantId = TenantId
		d.Tenants = utils.NewStringSet(TenantId, AnotherTenantId)
		d.Roles = utils.NewStringSet("SUPERUSER")
		d.Permissions = utils.NewStringSet("IS_API_ADMIN", "VIEW", "MANAGE")
	}
	return sectest.MockedAuthentication(customize)
}
// MemberOwnerOptions mocks an authenticated regular user (OwnerUserId) in TenantId. The user has access
// only to TenantId, the VIEW permission, and the USER role.
func MemberOwnerOptions() sectest.SecurityContextOptions {
	customize := func(d *sectest.SecurityDetailsMock) {
		d.Username = "testuser-member-owner"
		d.UserId = OwnerUserId
		d.ProviderId = ProviderId
		d.TenantId = TenantId
		d.Tenants = utils.NewStringSet(TenantId)
		d.Roles = utils.NewStringSet("USER")
		d.Permissions = utils.NewStringSet("VIEW")
	}
	return sectest.MockedAuthentication(customize)
}
// MemberNonOwnerOptions mocks an authenticated regular user (AnotherUserId) in TenantId. The user has
// access only to TenantId, the VIEW permission, and the USER role.
func MemberNonOwnerOptions() sectest.SecurityContextOptions {
	customize := func(d *sectest.SecurityDetailsMock) {
		d.Username = "testuser-member-non-owner"
		d.UserId = AnotherUserId
		d.ProviderId = ProviderId
		d.TenantId = TenantId
		d.Tenants = utils.NewStringSet(TenantId)
		d.Roles = utils.NewStringSet("USER")
		d.Permissions = utils.NewStringSet("VIEW")
	}
	return sectest.MockedAuthentication(customize)
}
// NonMemberAdminOptions mocks an authenticated admin (AnotherUserId) whose current tenant is
// AnotherTenantId. The user has access only to AnotherTenantId, the IS_API_ADMIN permission, and the
// SUPERUSER role.
func NonMemberAdminOptions() sectest.SecurityContextOptions {
	customize := func(d *sectest.SecurityDetailsMock) {
		d.Username = "testuser-non-member"
		d.UserId = AnotherUserId
		d.ProviderId = ProviderId
		d.TenantId = AnotherTenantId
		d.Tenants = utils.NewStringSet(AnotherTenantId)
		d.Roles = utils.NewStringSet("SUPERUSER")
		d.Permissions = utils.NewStringSet("IS_API_ADMIN")
	}
	return sectest.MockedAuthentication(customize)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opa
import (
"encoding/json"
"fmt"
"reflect"
)
// marshalMergedJSON marshals obj to JSON and shallow-merges the top-level keys of extra into the result;
// extra wins on key conflicts.
// obj is expected to encode as a JSON object (e.g. a struct or a map). A value that encodes as JSON null
// (e.g. a nil pointer or map) is treated as an empty object. Any other non-object encoding yields an error.
// processors are invoked on the merged map after the merge and before the final marshalling.
// When neither extra nor processors are given, the plain json.Marshal output is returned unchanged.
func marshalMergedJSON(obj interface{}, extra map[string]interface{}, processors ...func(m map[string]interface{})) ([]byte, error) {
	data, e := json.Marshal(obj)
	if e != nil || len(extra) == 0 && len(processors) == 0 {
		return data, e
	}

	// merge extra
	var m map[string]interface{}
	if e := json.Unmarshal(data, &m); e != nil {
		return nil, fmt.Errorf("unable to merge JSON: %w", e)
	}
	if m == nil {
		// JSON null unmarshals into a nil map without error; writing to it would panic.
		m = make(map[string]interface{}, len(extra))
	}
	for k, v := range extra {
		m[k] = v
	}
	for _, fn := range processors {
		fn(m)
	}
	return json.Marshal(m)
}
// minimizeMap recursively removes zero-valued entries from m, in place. An entry is removed when its value
// is nil, is the zero value of its type, or is a map/slice that is empty after its own minimization.
// Nested maps are minimized in place; nested slices are replaced by minimized copies.
func minimizeMap(m map[string]interface{}) {
	minimize(reflect.ValueOf(m))
}

// minimize minimizes rv (unwrapping one level of interface) and reports whether the result should be
// treated as zero. Maps are minimized in place; slices are returned as new, minimized slices.
func minimize(rv reflect.Value) (minimized reflect.Value, isZero bool) {
	if rv.Kind() == reflect.Interface {
		rv = rv.Elem()
	}
	// IsZero panics on an invalid Value (e.g. an unwrapped nil interface), so check validity first.
	isZero = !rv.IsValid() || rv.IsZero()
	//nolint:exhaustive // we only deal with map and slice
	switch rv.Kind() {
	case reflect.Map:
		rv = minimizeMapValue(rv)
		isZero = isZero || rv.Len() == 0
	case reflect.Slice:
		rv = minimizeSliceValue(rv)
		isZero = isZero || rv.Len() == 0
	}
	return rv, isZero
}

// minimizeMapValue minimizes every entry of mapV in place and deletes the entries that end up zero.
func minimizeMapValue(mapV reflect.Value) reflect.Value {
	// MapKeys returns a snapshot, so updating/deleting entries while looping is safe.
	for _, k := range mapV.MapKeys() {
		v, zero := minimize(mapV.MapIndex(k))
		if zero {
			mapV.SetMapIndex(k, reflect.Value{}) // the zero Value deletes the key
		} else {
			mapV.SetMapIndex(k, v)
		}
	}
	return mapV
}

// minimizeSliceValue returns a new slice of sliceV's type holding only its non-zero, minimized elements.
func minimizeSliceValue(sliceV reflect.Value) reflect.Value {
	// MakeSlice requires the slice type itself, not its element type.
	newV := reflect.MakeSlice(sliceV.Type(), 0, sliceV.Len())
	for i := 0; i < sliceV.Len(); i++ {
		if v, zero := minimize(sliceV.Index(i)); !zero {
			newV = reflect.Append(newV, v)
		}
	}
	return newV
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"context"
"encoding/json"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/opensearch-project/opensearch-go"
"github.com/opensearch-project/opensearch-go/opensearchutil"
"strconv"
"strings"
)
// BulkAction is the action applied to every document of a bulk request (see BulkIndexer); its string
// value is used as the bulk item action. It is intended to be used as an enum type.
//
// [REF]: https://opensearch.org/docs/1.2/opensearch/rest-api/document-apis/bulk/#request-body
type BulkAction string
// Supported bulk actions.
const (
BulkActionIndex BulkAction = "index" // Adds a document, overriding any existing document with the same ID
BulkActionCreate BulkAction = "create" // Adds a document if it doesn't exist, otherwise returns an error
BulkActionUpdate BulkAction = "update" // Updates an existing document if it exists, otherwise returns an error
BulkActionDelete BulkAction = "delete" // Deletes a document if it exists, otherwise returns `not_found`
)
// BulkIndexer JSON-encodes every item of *bulkItems and submits them with the given action through the
// underlying client's BulkIndexer. It returns the indexer's statistics, or the first encoding/indexing
// error. bulkItems must not be nil.
func (c *RepoImpl[T]) BulkIndexer(ctx context.Context, action BulkAction, bulkItems *[]T, o ...Option[opensearchutil.BulkIndexerConfig]) (opensearchutil.BulkIndexerStats, error) {
	items := *bulkItems
	documents := make([][]byte, 0, len(items))
	for _, item := range items {
		encoded, err := json.Marshal(item)
		if err != nil {
			return opensearchutil.BulkIndexerStats{}, err
		}
		documents = append(documents, encoded)
	}

	indexer, err := c.client.BulkIndexer(ctx, action, documents, o...)
	if err != nil {
		return opensearchutil.BulkIndexerStats{}, err
	}
	return indexer.Stats(), nil
}
// BulkIndexer creates an opensearchutil.BulkIndexer bound to this client and adds every document with the
// given action. It then closes the indexer so that pending items are flushed before it returns.
//
// Before hooks run, in order, before the indexer is created. They may customize the configuration
// through BeforeContext.Options, a *[]func(*opensearchutil.BulkIndexerConfig). After hooks run once all
// documents have been added, before the indexer is closed.
// If Close fails, the indexer is still returned along with the error.
func (c *OpenClientImpl) BulkIndexer(ctx context.Context, action BulkAction, documents [][]byte, o ...Option[opensearchutil.BulkIndexerConfig]) (opensearchutil.BulkIndexer, error) {
	options := make([]func(config *opensearchutil.BulkIndexerConfig), 0, len(o)+1)
	for _, v := range o {
		options = append(options, v)
	}
	options = append(options, WithClient(c.client))

	order.SortStable(c.beforeHook, order.OrderedFirstCompare)
	for _, hook := range c.beforeHook {
		ctx = hook.Before(ctx, BeforeContext{cmd: CmdBulk, Options: &options})
	}

	cfg := MakeConfig(options...)
	bi, err := opensearchutil.NewBulkIndexer(*cfg)
	if err != nil {
		return nil, err
	}
	for _, item := range documents {
		err = bi.Add(ctx, opensearchutil.BulkIndexerItem{
			Action: string(action),
			Body:   strings.NewReader(string(item)),
		})
		if err != nil {
			// The indexer started background workers (and a flush ticker) when it was created. Close it
			// so they don't leak; the Add failure is the error worth reporting.
			_ = bi.Close(ctx)
			return nil, err
		}
	}

	for _, hook := range c.afterHook {
		ctx = hook.After(ctx, AfterContext{cmd: CmdBulk, Options: &options, Err: &err})
	}
	if err = bi.Close(ctx); err != nil {
		return bi, err
	}
	return bi, nil
}
// MakeConfig builds a BulkIndexerConfig by applying options, in order, to a zero-valued config.
func MakeConfig(options ...func(*opensearchutil.BulkIndexerConfig)) *opensearchutil.BulkIndexerConfig {
	var cfg opensearchutil.BulkIndexerConfig
	for i := range options {
		options[i](&cfg)
	}
	return &cfg
}
// WithClient returns a bulk indexer config option that sets the OpenSearch client used to send requests.
func WithClient(c *opensearch.Client) func(*opensearchutil.BulkIndexerConfig) {
	setClient := func(cfg *opensearchutil.BulkIndexerConfig) {
		cfg.Client = c
	}
	return setClient
}
// bulkCfgExt groups helpers that build bulk indexer config options. It is exposed through the
// BulkIndexer variable, e.g. BulkIndexer.WithWorkers(4).
type bulkCfgExt struct {
	opensearchutil.BulkIndexerConfig
}

// BulkIndexer is the entry point to the bulk indexer config helpers (see the bulkCfgExt methods).
var BulkIndexer = bulkCfgExt{}

// WithWorkers returns an option that sets the number of bulk indexer workers.
func (bulkCfgExt) WithWorkers(n int) func(*opensearchutil.BulkIndexerConfig) {
	return func(cfg *opensearchutil.BulkIndexerConfig) { cfg.NumWorkers = n }
}

// WithRefresh returns an option that sets the bulk "refresh" parameter to "true" or "false".
func (bulkCfgExt) WithRefresh(b bool) func(*opensearchutil.BulkIndexerConfig) {
	refresh := strconv.FormatBool(b)
	return func(cfg *opensearchutil.BulkIndexerConfig) { cfg.Refresh = refresh }
}

// WithIndex returns an option that sets the bulk indexer's index.
func (bulkCfgExt) WithIndex(i string) func(*opensearchutil.BulkIndexerConfig) {
	return func(cfg *opensearchutil.BulkIndexerConfig) { cfg.Index = i }
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/opensearch-project/opensearch-go"
"github.com/opensearch-project/opensearch-go/opensearchapi"
"github.com/opensearch-project/opensearch-go/opensearchutil"
"go.uber.org/fx"
"io"
"net/http"
"reflect"
)
var (
// ErrCreatingNewClient is wrapped by NewClient when the underlying opensearch client cannot be created.
ErrCreatingNewClient = errors.New("unable to create opensearch client")
)
// Request is the type-set constraint listing the request/config types that can be customized through Option.
type Request interface {
opensearchapi.SearchRequest |
opensearchapi.SearchTemplateRequest |
opensearchapi.IndicesCreateRequest |
opensearchapi.IndexRequest |
opensearchutil.BulkIndexerConfig |
opensearchapi.IndicesDeleteRequest |
opensearchapi.IndicesGetRequest |
opensearchapi.IndicesPutAliasRequest |
opensearchapi.IndicesDeleteAliasRequest |
opensearchapi.IndicesPutIndexTemplateRequest |
opensearchapi.IndicesDeleteIndexTemplateRequest |
opensearchapi.PingRequest
}
// OpenClient is this package's OpenSearch client abstraction; NewClient returns an *OpenClientImpl as its
// implementation. Each command accepts functional options (Option) for its request type. Registered
// before/after hooks are notified around commands (as done in OpenClientImpl.BulkIndexer).
type OpenClient interface {
Search(ctx context.Context, o ...Option[opensearchapi.SearchRequest]) (*opensearchapi.Response, error)
SearchTemplate(ctx context.Context, body io.Reader, o ...Option[opensearchapi.SearchTemplateRequest]) (*opensearchapi.Response, error)
Index(ctx context.Context, index string, body io.Reader, o ...Option[opensearchapi.IndexRequest]) (*opensearchapi.Response, error)
BulkIndexer(ctx context.Context, action BulkAction, bulkItems [][]byte, o ...Option[opensearchutil.BulkIndexerConfig]) (opensearchutil.BulkIndexer, error)
IndicesCreate(ctx context.Context, index string, o ...Option[opensearchapi.IndicesCreateRequest]) (*opensearchapi.Response, error)
IndicesGet(ctx context.Context, index string, o ...Option[opensearchapi.IndicesGetRequest]) (*opensearchapi.Response, error)
IndicesDelete(ctx context.Context, index []string, o ...Option[opensearchapi.IndicesDeleteRequest]) (*opensearchapi.Response, error)
IndicesPutAlias(ctx context.Context, index []string, name string, o ...Option[opensearchapi.IndicesPutAliasRequest]) (*opensearchapi.Response, error)
IndicesDeleteAlias(ctx context.Context, index []string, name []string, o ...Option[opensearchapi.IndicesDeleteAliasRequest]) (*opensearchapi.Response, error)
IndicesPutIndexTemplate(ctx context.Context, name string, body io.Reader, o ...Option[opensearchapi.IndicesPutIndexTemplateRequest]) (*opensearchapi.Response, error)
IndicesDeleteIndexTemplate(ctx context.Context, name string, o ...Option[opensearchapi.IndicesDeleteIndexTemplateRequest]) (*opensearchapi.Response, error)
Ping(ctx context.Context, o ...Option[opensearchapi.PingRequest]) (*opensearchapi.Response, error)
// AddBeforeHook and AddAfterHook register hooks; RemoveBeforeHook and RemoveAfterHook unregister them.
AddBeforeHook(hook BeforeHook)
AddAfterHook(hook AfterHook)
RemoveBeforeHook(hook BeforeHook)
RemoveAfterHook(hook AfterHook)
}
// Option is a functional option that mutates a request/config value of type T.
type Option[T Request] func(request *T)
const (
// FxGroup is the fx value group ("opensearch") in which before/after hooks are collected; it matches
// the group tags of newClientDI.
FxGroup = "opensearch"
)
// newClientDI holds NewClient's fx dependencies. Hooks are collected from the "opensearch" fx group
// (FxGroup); struct tags cannot reference the constant, hence the literal.
type newClientDI struct {
fx.In
Config opensearch.Config
BeforeHook []BeforeHook `group:"opensearch"`
AfterHook []AfterHook `group:"opensearch"`
}
// NewClient creates the OpenSearch client from the injected config and wraps it in an OpenClientImpl.
// The before/after hooks injected from the fx group are sorted by order. Client creation failures
// wrap ErrCreatingNewClient.
func NewClient(di newClientDI) (OpenClient, error) {
	client, err := opensearch.NewClient(di.Config)
	if err != nil {
		return nil, fmt.Errorf("%w, %v", ErrCreatingNewClient, err)
	}

	before, after := di.BeforeHook, di.AfterHook
	order.SortStable(before, order.OrderedFirstCompare)
	order.SortStable(after, order.OrderedFirstCompare)
	return &OpenClientImpl{
		client:     client,
		beforeHook: before,
		afterHook:  after,
	}, nil
}
// configDI holds NewConfig's fx dependencies.
type configDI struct {
fx.In
Properties *Properties
// TLSCertsManager is optional; NewConfig returns an error if it is missing while TLS is enabled.
TLSCertsManager certs.Manager `optional:"true"`
}
// NewConfig builds the opensearch.Config from Properties: node addresses and basic-auth credentials.
// When TLS is enabled, it also sets an HTTP transport that uses the TLS config obtained from the
// certificate manager.
// An error is returned in three cases: TLS is enabled but no certs.Manager is available, the TLS source
// cannot be created, or its tls.Config cannot be obtained.
func NewConfig(ctx *bootstrap.ApplicationContext, di configDI) (opensearch.Config, error) {
	conf := opensearch.Config{
		Addresses: di.Properties.Addresses,
		Username:  di.Properties.Username,
		Password:  di.Properties.Password,
	}
	if !di.Properties.TLS.Enable {
		return conf, nil
	}

	if di.TLSCertsManager == nil {
		return conf, fmt.Errorf("TLS is enabled of OpenSearch, but certificate manager is not initialized")
	}
	tlsSrc, err := di.TLSCertsManager.Source(ctx, certs.WithSourceProperties(&di.Properties.TLS.Certs))
	if err != nil {
		return conf, err
	}
	tlsConf, err := tlsSrc.TLSConfig(ctx)
	if err != nil {
		// %w keeps the cause inspectable with errors.Is / errors.As.
		return conf, fmt.Errorf("failed to initialize TLS for OpenSearch: %w", err)
	}
	// NOTE(review): a bare http.Transport does not inherit http.DefaultTransport's proxy and timeout
	// settings; consider cloning the default transport if those matter.
	conf.Transport = &http.Transport{
		TLSClientConfig: tlsConf,
	}
	return conf, nil
}
// OpenClientImpl is the OpenClient implementation returned by NewClient. Its hook slices are kept sorted
// by order (NewClient, AddBeforeHook and AddAfterHook sort them).
type OpenClientImpl struct {
client *opensearch.Client
beforeHook []BeforeHook
afterHook []AfterHook
}
// CommandType lets the hooks know what command is being run
type CommandType int
const (
// UnknownCommand is the name CommandType.String returns for values missing from CmdToString.
UnknownCommand string = "unknown"
)
// Known commands. Values come from iota, so new commands should be appended at the end to keep existing
// values stable.
const (
CmdSearch CommandType = iota
CmdSearchTemplate
CmdIndex
CmdIndicesCreate
CmdIndicesGet
CmdIndicesDelete
CmdIndicesPutAlias
CmdIndicesDeleteAlias
CmdIndicesPutIndexTemplate
CmdIndicesDeleteIndexTemplate
CmdPing
CmdBulk
)
// CmdToString maps each CommandType to the human-readable name returned by CommandType.String.
var CmdToString = map[CommandType]string{
CmdSearch: "search",
CmdSearchTemplate: "search template",
CmdIndex: "index",
CmdIndicesCreate: "indices create",
CmdIndicesGet: "indices get",
CmdIndicesDelete: "indices delete",
CmdIndicesPutAlias: "indices put alias",
CmdIndicesDeleteAlias: "indices delete alias",
CmdIndicesPutIndexTemplate: "indices put index template",
CmdIndicesDeleteIndexTemplate: "indices delete index template",
CmdPing: "ping",
CmdBulk: "bulk",
}
// String returns the command's name from CmdToString. Unknown values are logged as an error and reported
// as UnknownCommand.
func (c CommandType) String() string {
	val, ok := CmdToString[c]
	if !ok {
		// Log the raw integer: formatting c itself with %v would call String again and recurse forever.
		logger.Errorf("unknown command: %d", int(c))
		return UnknownCommand
	}
	return val
}
// AddBeforeHook registers hook and re-sorts the before hooks by order. The sort is stable, so hooks with
// equal order keep their registration order.
func (c *OpenClientImpl) AddBeforeHook(hook BeforeHook) {
	c.beforeHook = append(c.beforeHook, hook)
	// keep the slice ordered so command paths can iterate it directly
	order.SortStable(c.beforeHook, order.OrderedFirstCompare)
}

// AddAfterHook registers hook and re-sorts the after hooks by order. The sort is stable, so hooks with
// equal order keep their registration order.
func (c *OpenClientImpl) AddAfterHook(hook AfterHook) {
	c.afterHook = append(c.afterHook, hook)
	// keep the slice ordered so command paths can iterate it directly
	order.SortStable(c.afterHook, order.OrderedFirstCompare)
}
// RemoveBeforeHook removes every registered BeforeHook that matches hook, preserving the order of the rest.
// If hook implements Identifier, registered hooks that also implement Identifier and report the same ID()
// are removed. Otherwise hooks are matched with reflect.DeepEqual. To ensure your hook is removable, it
// should implement the Identifier interface; if not, your hooks should be distinct in the eyes of
// reflect.DeepEqual, otherwise the hook will not be removed.
func (c *OpenClientImpl) RemoveBeforeHook(hook BeforeHook) {
	matches := func(h BeforeHook) bool {
		return reflect.DeepEqual(h, hook)
	}
	if target, ok := hook.(Identifier); ok {
		matches = func(h BeforeHook) bool {
			other, hok := h.(Identifier)
			return hok && target.ID() == other.ID()
		}
	}
	// Walk backwards: removing index i only shifts elements after i, which were already visited. No
	// element is skipped and i always stays within the shrinking slice.
	for i := len(c.beforeHook) - 1; i >= 0; i-- {
		if matches(c.beforeHook[i]) {
			c.beforeHook = utils.RemoveStable(c.beforeHook, i)
		}
	}
}
// RemoveAfterHook will remove every registered AfterHook matching the given hook.
// To ensure your hook is removable, the hook should implement the Identifier
// interface: registered hooks that also implement Identifier and report the same
// ID() are removed. If not, your hooks should be distinct in the eyes of
// reflect.DeepEqual, otherwise the hook will not be removed (func-based hooks such
// as AfterHookFunc are never DeepEqual unless nil, so they cannot be removed).
func (c *OpenClientImpl) RemoveAfterHook(hook AfterHook) {
	target, hasID := hook.(Identifier)
	// Walk backwards: removing index i only shifts elements that were already
	// visited. Ranging forwards while removing skips the element after each
	// removal and, with several matches, removes the wrong index.
	for i := len(c.afterHook) - 1; i >= 0; i-- {
		h := c.afterHook[i]
		var match bool
		if hasID {
			candidate, ok := h.(Identifier)
			match = ok && candidate.ID() == target.ID()
		} else {
			match = reflect.DeepEqual(h, hook)
		}
		if match {
			c.afterHook = utils.RemoveStable(c.afterHook, i)
		}
	}
}
// BeforeContext is the context given to a BeforeHook
//
// Options will be in the form *[]func(request *Request){}, example:
//
// options := make([]func(request *opensearchapi.SearchRequest), 0)
// BeforeContext{Options: &options}
//
// Because Options is a pointer to the command's option slice, a hook may append
// to or replace that slice; the resulting options are used for the request.
type BeforeContext struct {
// cmd identifies the command being run; exposed through CommandType().
cmd CommandType
Options interface{}
}
// CommandType returns the command that is about to be run.
func (c *BeforeContext) CommandType() CommandType {
return c.cmd
}
// BeforeHook is invoked before a command is sent to OpenSearch. The context it
// returns replaces the one used for the rest of the command.
type BeforeHook interface {
Before(ctx context.Context, before BeforeContext) context.Context
}
// Identifier is implemented by hooks that can be matched, and therefore removed,
// by ID through RemoveBeforeHook / RemoveAfterHook.
type Identifier interface {
ID() string
}
// BeforeHookBase provides a way to create a BeforeHook, similar to BeforeHookFunc,
// but in a way that implements the Identifier interface so that it can be removed using
// the RemoveBeforeHook function
type BeforeHookBase struct {
	// Identifier is returned by ID and is what RemoveBeforeHook matches on.
	Identifier string
	// F is the function invoked by Before.
	F func(ctx context.Context, before BeforeContext) context.Context
}

// Before implements BeforeHook by delegating to F.
func (s BeforeHookBase) Before(ctx context.Context, before BeforeContext) context.Context {
	return s.F(ctx, before)
}

// ID implements Identifier by returning the configured Identifier.
func (s BeforeHookBase) ID() string {
	return s.Identifier
}

// BeforeHookFunc provides a way to easily create a BeforeHook from a plain function.
// Hooks created in this manner do not implement Identifier, and since non-nil func
// values are never reflect.DeepEqual, they are not able to be removed from the hook
// slice; use BeforeHookBase when removal is needed.
type BeforeHookFunc func(ctx context.Context, before BeforeContext) context.Context

// Before implements BeforeHook by calling f.
func (f BeforeHookFunc) Before(ctx context.Context, before BeforeContext) context.Context {
	return f(ctx, before)
}
// AfterContext is the context given to an AfterHook
//
// Options will be in the form *[]func(request *Request){} example:
//
// options := make([]func(request *opensearchapi.SearchRequest), 0)
// AfterContext{Options: &options}
//
// Err points at the error that the OpenClientImpl command method (for example
// OpenClientImpl.Search) is about to return, so a hook may replace it. Resp is the
// response pointer that will be returned; the hook can inspect or mutate the
// response it points to.
type AfterContext struct {
// cmd identifies the command that was run; exposed through CommandType().
cmd CommandType
Options interface{}
Resp *opensearchapi.Response
Err *error
}
// CommandType returns the command that was run.
func (c *AfterContext) CommandType() CommandType {
return c.cmd
}
// AfterHook is invoked after a command has been sent to OpenSearch.
type AfterHook interface {
After(ctx context.Context, after AfterContext) context.Context
}
// AfterHookBase builds an AfterHook from a function, much like AfterHookFunc, but
// also implements the Identifier interface so the hook can later be removed with
// the RemoveAfterHook function.
type AfterHookBase struct {
	Identifier string
	F          func(ctx context.Context, after AfterContext) context.Context
}

// After satisfies AfterHook by forwarding to the wrapped function F.
func (b AfterHookBase) After(ctx context.Context, after AfterContext) context.Context {
	fn := b.F
	return fn(ctx, after)
}

// ID satisfies Identifier by reporting the configured Identifier string.
func (b AfterHookBase) ID() string {
	id := b.Identifier
	return id
}

// AfterHookFunc adapts a plain function into an AfterHook. Hooks built this way
// cannot be removed from the hook slice afterwards.
type AfterHookFunc func(ctx context.Context, after AfterContext) context.Context

// After satisfies AfterHook by invoking the function itself.
func (fn AfterHookFunc) After(ctx context.Context, after AfterContext) context.Context {
	return fn(ctx, after)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
)
// HealthIndicator is a health indicator for OpenSearch that checks the cluster
// by pinging it through an OpenClient.
type HealthIndicator struct {
	client OpenClient
}

// Name returns the name this indicator is reported under.
func (i *HealthIndicator) Name() string {
	const name = "opensearch"
	return name
}

// NewHealthIndicator builds a HealthIndicator that pings through client.
func NewHealthIndicator(client OpenClient) *HealthIndicator {
	indicator := new(HealthIndicator)
	indicator.client = client
	return indicator
}
// Health pings the OpenSearch cluster and reports StatusUp when the ping succeeds
// with a non-error response. It reports StatusDown when the ping request fails or
// the response reports an error. options is currently unused and no details are
// attached to the result.
func (i *HealthIndicator) Health(c context.Context, options health.Options) health.Health {
	resp, err := i.client.Ping(c)
	if err != nil {
		logger.WithContext(c).Errorf("unable to ping opensearch: %v", err)
		return health.NewDetailedHealth(health.StatusDown, "opensearch ping failed", nil)
	}
	// Close the response body: net/http cannot reuse the underlying connection
	// for an unclosed body, and this check runs repeatedly.
	if resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if resp.IsError() {
		logger.WithContext(c).Errorf("unable to ping opensearch: %v", resp.String())
		return health.NewDetailedHealth(health.StatusDown, "opensearch ping failed", nil)
	}
	return health.NewDetailedHealth(health.StatusUp, "opensearch ping succeeded", nil)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/opensearch-project/opensearch-go/opensearchapi"
"io"
)
// Index JSON-encodes document and stores it as a new document in the given index.
//
// Options are forwarded to the underlying index request. The transport error is
// returned as-is. When the response reports an error, the response body is logged
// at debug level and an error carrying the HTTP status code is returned.
func (c *RepoImpl[T]) Index(ctx context.Context, index string, document T, o ...Option[opensearchapi.IndexRequest]) error {
	var buffer bytes.Buffer
	if err := json.NewEncoder(&buffer).Encode(document); err != nil {
		return err
	}
	resp, err := c.client.Index(ctx, index, &buffer, o...)
	if err != nil {
		return err
	}
	// Close the body so the HTTP transport can reuse the connection.
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if resp != nil && resp.IsError() {
		logger.WithContext(ctx).Debugf("error response: %s", resp.String())
		return fmt.Errorf("error status code: %d", resp.StatusCode)
	}
	return nil
}
// Index sends body as a document to the named index.
//
// Registered before hooks run first and may replace the context or edit the
// option list through BeforeContext.Options. The resulting context is attached to
// the request. After hooks then see the response and may overwrite the returned
// error through AfterContext.Err.
func (c *OpenClientImpl) Index(ctx context.Context, index string, body io.Reader, o ...Option[opensearchapi.IndexRequest]) (*opensearchapi.Response, error) {
	opts := make([]func(request *opensearchapi.IndexRequest), len(o))
	for i := range o {
		opts[i] = o[i]
	}

	before := BeforeContext{cmd: CmdIndex, Options: &opts}
	for _, h := range c.beforeHook {
		ctx = h.Before(ctx, before)
	}

	//nolint:makezero
	opts = append(opts, Index.WithContext(ctx))
	resp, err := c.client.API.Index(index, body, opts...)

	after := AfterContext{cmd: CmdIndex, Options: &opts, Resp: resp, Err: &err}
	for _, h := range c.afterHook {
		ctx = h.After(ctx, after)
	}
	return resp, err
}
// indexExt exposes the request-option helpers of opensearchapi.Index (for example
// WithContext, used by OpenClientImpl.Index) and can be extended with custom ones:
// func (s indexExt) WithSomething() func(request *opensearchapi.IndexRequest) {
// return func(request *opensearchapi.IndexRequest) {
// }
// }
type indexExt struct {
opensearchapi.Index
}
// Index is the entry point for index request options, e.g. Index.WithContext(ctx).
var Index = indexExt{}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/opensearch-project/opensearch-go/opensearchapi"
)
// IndicesCreate creates index in the cluster, sending mapping JSON-encoded as the
// request body.
//
// The body option is appended after the caller's options. The transport error is
// returned as-is. When the response reports an error, the response body is logged
// at debug level and an error carrying the HTTP status code is returned.
func (c *RepoImpl[T]) IndicesCreate(
	ctx context.Context,
	index string,
	mapping interface{},
	o ...Option[opensearchapi.IndicesCreateRequest],
) error {
	var buffer bytes.Buffer
	if err := json.NewEncoder(&buffer).Encode(mapping); err != nil {
		return fmt.Errorf("unable to encode mapping: %w", err)
	}
	// Cap the slice at its length so append always copies. Appending in place
	// could overwrite the backing array of a slice the caller passed as `opts...`.
	o = append(o[:len(o):len(o)], IndicesCreate.WithBody(&buffer))
	resp, err := c.client.IndicesCreate(ctx, index, o...)
	if err != nil {
		return err
	}
	// Close the body so the HTTP transport can reuse the connection.
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if resp != nil && resp.IsError() {
		logger.WithContext(ctx).Debugf("error response: %s", resp.String())
		return fmt.Errorf("error status code: %d", resp.StatusCode)
	}
	return nil
}
// IndicesCreate creates the named index.
//
// Registered before hooks run first and may replace the context or edit the
// option list through BeforeContext.Options. The resulting context is attached to
// the request. After hooks then see the response and may overwrite the returned
// error through AfterContext.Err.
func (c *OpenClientImpl) IndicesCreate(
	ctx context.Context,
	index string,
	o ...Option[opensearchapi.IndicesCreateRequest],
) (*opensearchapi.Response, error) {
	opts := make([]func(request *opensearchapi.IndicesCreateRequest), len(o))
	for i := range o {
		opts[i] = o[i]
	}

	before := BeforeContext{cmd: CmdIndicesCreate, Options: &opts}
	for _, h := range c.beforeHook {
		ctx = h.Before(ctx, before)
	}

	//nolint:makezero
	opts = append(opts, IndicesCreate.WithContext(ctx))
	resp, err := c.client.Indices.Create(index, opts...)

	after := AfterContext{cmd: CmdIndicesCreate, Options: &opts, Resp: resp, Err: &err}
	for _, h := range c.afterHook {
		ctx = h.After(ctx, after)
	}
	return resp, err
}
// indicesCreateExt exposes the request-option helpers of opensearchapi.IndicesCreate
// (for example WithBody and WithContext) and can be extended with custom ones.
type indicesCreateExt struct {
opensearchapi.IndicesCreate
}
// IndicesCreate is the entry point for index-creation request options,
// e.g. IndicesCreate.WithBody(reader).
var IndicesCreate = indicesCreateExt{}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"context"
"fmt"
"github.com/opensearch-project/opensearch-go/opensearchapi"
)
// IndicesDelete deletes the given indices from the cluster.
//
// The transport error is returned as-is. When the response reports an error, the
// response body is logged at debug level and an error carrying the HTTP status
// code is returned.
func (c *RepoImpl[T]) IndicesDelete(ctx context.Context, index []string, o ...Option[opensearchapi.IndicesDeleteRequest]) error {
	resp, err := c.client.IndicesDelete(ctx, index, o...)
	if err != nil {
		return err
	}
	// Close the body so the HTTP transport can reuse the connection.
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if resp != nil && resp.IsError() {
		logger.WithContext(ctx).Debugf("error response: %s", resp.String())
		return fmt.Errorf("error status code: %d", resp.StatusCode)
	}
	return nil
}
// IndicesDelete deletes the given indices.
//
// Registered before hooks run first and may replace the context or edit the
// option list through BeforeContext.Options. The resulting context is attached to
// the request. After hooks then see the response and may overwrite the returned
// error through AfterContext.Err.
func (c *OpenClientImpl) IndicesDelete(ctx context.Context, index []string, o ...Option[opensearchapi.IndicesDeleteRequest]) (*opensearchapi.Response, error) {
	opts := make([]func(request *opensearchapi.IndicesDeleteRequest), len(o))
	for i := range o {
		opts[i] = o[i]
	}

	before := BeforeContext{cmd: CmdIndicesDelete, Options: &opts}
	for _, h := range c.beforeHook {
		ctx = h.Before(ctx, before)
	}

	//nolint:makezero
	opts = append(opts, IndicesDelete.WithContext(ctx))
	resp, err := c.client.API.Indices.Delete(index, opts...)

	after := AfterContext{cmd: CmdIndicesDelete, Options: &opts, Resp: resp, Err: &err}
	for _, h := range c.afterHook {
		ctx = h.After(ctx, after)
	}
	return resp, err
}
// indicesDeleteExt exposes the request-option helpers of opensearchapi.IndicesDelete
// (for example WithContext) and can be extended with custom ones.
type indicesDeleteExt struct {
opensearchapi.IndicesDelete
}
// IndicesDelete is the entry point for index-deletion request options.
var IndicesDelete = indicesDeleteExt{}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"context"
"fmt"
"github.com/opensearch-project/opensearch-go/opensearchapi"
)
// IndicesDeleteAlias deletes the aliases in name from the given indices.
//
// The transport error is returned as-is. When the response reports an error, the
// response body is logged at debug level and an error carrying the HTTP status
// code is returned.
func (c *RepoImpl[T]) IndicesDeleteAlias(ctx context.Context, index []string, name []string, o ...Option[opensearchapi.IndicesDeleteAliasRequest]) error {
	resp, err := c.client.IndicesDeleteAlias(ctx, index, name, o...)
	if err != nil {
		return err
	}
	// Close the body so the HTTP transport can reuse the connection.
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if resp != nil && resp.IsError() {
		logger.WithContext(ctx).Debugf("error response: %s", resp.String())
		return fmt.Errorf("error status code: %d", resp.StatusCode)
	}
	return nil
}
// IndicesDeleteAlias deletes the aliases in name from the given indices.
//
// Registered before hooks run first and may replace the context or edit the
// option list through BeforeContext.Options. The resulting context is attached to
// the request. After hooks then see the response and may overwrite the returned
// error through AfterContext.Err.
func (c *OpenClientImpl) IndicesDeleteAlias(ctx context.Context, index []string, name []string, o ...Option[opensearchapi.IndicesDeleteAliasRequest]) (*opensearchapi.Response, error) {
	opts := make([]func(request *opensearchapi.IndicesDeleteAliasRequest), len(o))
	for i := range o {
		opts[i] = o[i]
	}

	before := BeforeContext{cmd: CmdIndicesDeleteAlias, Options: &opts}
	for _, h := range c.beforeHook {
		ctx = h.Before(ctx, before)
	}

	//nolint:makezero
	opts = append(opts, IndicesDeleteAlias.WithContext(ctx))
	resp, err := c.client.API.Indices.DeleteAlias(index, name, opts...)

	after := AfterContext{cmd: CmdIndicesDeleteAlias, Options: &opts, Resp: resp, Err: &err}
	for _, h := range c.afterHook {
		ctx = h.After(ctx, after)
	}
	return resp, err
}
// indicesDeleteAlias exposes the request-option helpers of
// opensearchapi.IndicesDeleteAlias (for example WithContext) and can be extended.
type indicesDeleteAlias struct {
opensearchapi.IndicesDeleteAlias
}
// IndicesDeleteAlias is the entry point for alias-deletion request options.
var IndicesDeleteAlias = indicesDeleteAlias{}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"context"
"fmt"
"github.com/opensearch-project/opensearch-go/opensearchapi"
)
// IndicesDeleteIndexTemplate deletes the index template called name.
//
// The transport error is returned as-is. When the response reports an error, the
// response body is logged at debug level and an error carrying the HTTP status
// code is returned.
func (c *RepoImpl[T]) IndicesDeleteIndexTemplate(ctx context.Context, name string, o ...Option[opensearchapi.IndicesDeleteIndexTemplateRequest]) error {
	resp, err := c.client.IndicesDeleteIndexTemplate(ctx, name, o...)
	if err != nil {
		return err
	}
	// Close the body so the HTTP transport can reuse the connection.
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if resp != nil && resp.IsError() {
		logger.WithContext(ctx).Debugf("error response: %s", resp.String())
		return fmt.Errorf("error status code: %d", resp.StatusCode)
	}
	return nil
}
// IndicesDeleteIndexTemplate deletes the index template called name.
//
// Like the other OpenClientImpl commands, it runs the registered before hooks
// (which may replace the context or edit the options) with
// CmdIndicesDeleteIndexTemplate, sends the request with the resulting context, and
// then runs the after hooks, which may overwrite the returned error through
// AfterContext.Err.
func (c *OpenClientImpl) IndicesDeleteIndexTemplate(ctx context.Context, name string, o ...Option[opensearchapi.IndicesDeleteIndexTemplateRequest]) (*opensearchapi.Response, error) {
	options := make([]func(request *opensearchapi.IndicesDeleteIndexTemplateRequest), len(o))
	for i, v := range o {
		options[i] = v
	}
	for _, hook := range c.beforeHook {
		ctx = hook.Before(ctx, BeforeContext{cmd: CmdIndicesDeleteIndexTemplate, Options: &options})
	}
	//nolint:makezero
	options = append(options, IndicesDeleteIndexTemplate.WithContext(ctx))
	resp, err := c.client.API.Indices.DeleteIndexTemplate(name, options...)
	for _, hook := range c.afterHook {
		ctx = hook.After(ctx, AfterContext{cmd: CmdIndicesDeleteIndexTemplate, Options: &options, Resp: resp, Err: &err})
	}
	return resp, err
}
// indicesDeleteIndexTemplate exposes the request-option helpers of
// opensearchapi.IndicesDeleteIndexTemplate (for example WithContext).
type indicesDeleteIndexTemplate struct {
opensearchapi.IndicesDeleteIndexTemplate
}
// IndicesDeleteIndexTemplate is the entry point for index-template deletion options.
var IndicesDeleteIndexTemplate = indicesDeleteIndexTemplate{}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"context"
"encoding/json"
"fmt"
"github.com/opensearch-project/opensearch-go/opensearchapi"
"io"
"net/http"
)
// IndicesDetail response follows opensearch spec
// [format] https://opensearch.org/docs/latest/opensearch/rest-api/index-apis/get-index/#response-body-fields
type IndicesDetail struct {
Aliases map[string]interface{} `json:"aliases"`
Mappings map[string]interface{} `json:"mappings"`
// NOTE(review): the tag is "Settings" while the response field is presumably
// "settings"; encoding/json matches keys case-insensitively, so it still decodes.
Settings struct {
Index struct {
// The index settings below are decoded as strings (they must arrive as JSON
// strings for json.Unmarshal to succeed).
CreationDate string `json:"creation_date"`
NumberOfShards string `json:"number_of_shards"`
NumberOfReplicas string `json:"number_of_replicas"`
Uuid string `json:"uuid"`
Version struct {
Created string `json:"created"`
} `json:"version"`
ProvidedName string `json:"provided_name"`
} `json:"index"`
} `json:"Settings"`
}
// IndicesGet returns the details of the index called (or aliased as) index.
//
// The response is a JSON object keyed by concrete index name. If it contains more
// than one index, an error is returned. If it contains none, (nil, nil) is
// returned. When the response reports an error, the response body is logged at
// debug level and an error carrying the HTTP status code is returned.
func (c *RepoImpl[T]) IndicesGet(ctx context.Context, index string, o ...Option[opensearchapi.IndicesGetRequest]) (*IndicesDetail, error) {
	resp, err := c.client.IndicesGet(ctx, index, o...)
	if err != nil {
		return nil, err
	}
	// Close the body so the HTTP transport can reuse the connection.
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if resp != nil && resp.IsError() {
		logger.WithContext(ctx).Debugf("error response: %s", resp.String())
		return nil, fmt.Errorf("error status code: %d", resp.StatusCode)
	}
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	indicesDetail := make(map[string]*IndicesDetail)
	if err = json.Unmarshal(respBody, &indicesDetail); err != nil {
		return nil, err
	}
	if len(indicesDetail) > 1 {
		return nil, fmt.Errorf("error status code: %d, more than one index exists, with the same alias/name: %s ", http.StatusInternalServerError, index)
	}
	// The single top-level key is the concrete index name, which differs from
	// `index` when an alias was queried, so take whichever entry is present.
	var detail *IndicesDetail
	for _, d := range indicesDetail {
		detail = d
	}
	return detail, nil
}
// IndicesGet fetches information about the index called (or aliased as) index.
//
// Like the other OpenClientImpl commands, it runs the registered before hooks
// (which may replace the context or edit the options) with CmdIndicesGet, sends the
// request with the resulting context, and then runs the after hooks, which may
// overwrite the returned error through AfterContext.Err.
func (c *OpenClientImpl) IndicesGet(ctx context.Context, index string, o ...Option[opensearchapi.IndicesGetRequest]) (*opensearchapi.Response, error) {
	options := make([]func(request *opensearchapi.IndicesGetRequest), len(o))
	for i, v := range o {
		options[i] = v
	}
	for _, hook := range c.beforeHook {
		ctx = hook.Before(ctx, BeforeContext{cmd: CmdIndicesGet, Options: &options})
	}
	//nolint:makezero
	options = append(options, IndicesGet.WithContext(ctx))
	resp, err := c.client.Indices.Get([]string{index}, options...)
	for _, hook := range c.afterHook {
		ctx = hook.After(ctx, AfterContext{cmd: CmdIndicesGet, Options: &options, Resp: resp, Err: &err})
	}
	return resp, err
}
// indicesGetExt exposes the request-option helpers of opensearchapi.IndicesGet
// (for example WithContext) and can be extended with custom ones.
type indicesGetExt struct {
opensearchapi.IndicesGet
}
// IndicesGet is the entry point for index-get request options.
var IndicesGet = indicesGetExt{}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"context"
"fmt"
"github.com/opensearch-project/opensearch-go/opensearchapi"
)
// IndicesPutAlias creates or updates the alias name so it points at the given indices.
//
// The transport error is returned as-is. When the response reports an error, the
// response body is logged at debug level and an error carrying the HTTP status
// code is returned.
func (c *RepoImpl[T]) IndicesPutAlias(ctx context.Context,
	index []string,
	name string,
	o ...Option[opensearchapi.IndicesPutAliasRequest],
) error {
	resp, err := c.client.IndicesPutAlias(ctx, index, name, o...)
	if err != nil {
		return err
	}
	// Close the body so the HTTP transport can reuse the connection.
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if resp != nil && resp.IsError() {
		logger.WithContext(ctx).Debugf("error response: %s", resp.String())
		return fmt.Errorf("error status code: %d", resp.StatusCode)
	}
	return nil
}
// IndicesPutAlias creates or updates the alias name for the given indices.
//
// Registered before hooks run first and may replace the context or edit the
// option list through BeforeContext.Options. The resulting context is attached to
// the request. After hooks then see the response and may overwrite the returned
// error through AfterContext.Err.
func (c *OpenClientImpl) IndicesPutAlias(ctx context.Context, index []string, name string, o ...Option[opensearchapi.IndicesPutAliasRequest]) (*opensearchapi.Response, error) {
	opts := make([]func(request *opensearchapi.IndicesPutAliasRequest), len(o))
	for i := range o {
		opts[i] = o[i]
	}

	before := BeforeContext{cmd: CmdIndicesPutAlias, Options: &opts}
	for _, h := range c.beforeHook {
		ctx = h.Before(ctx, before)
	}

	//nolint:makezero
	opts = append(opts, IndicesPutAlias.WithContext(ctx))
	resp, err := c.client.API.Indices.PutAlias(index, name, opts...)

	after := AfterContext{cmd: CmdIndicesPutAlias, Options: &opts, Resp: resp, Err: &err}
	for _, h := range c.afterHook {
		ctx = h.After(ctx, after)
	}
	return resp, err
}
// indicesPutAlias exposes the request-option helpers of opensearchapi.IndicesPutAlias
// (for example WithContext) and can be extended with custom ones.
type indicesPutAlias struct {
opensearchapi.IndicesPutAlias
}
// IndicesPutAlias is the entry point for alias put request options.
var IndicesPutAlias = indicesPutAlias{}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/opensearch-project/opensearch-go/opensearchapi"
"io"
)
// IndicesPutIndexTemplate creates or updates the index template called name,
// sending body JSON-encoded as the template definition.
//
// The transport error is returned as-is. When the response reports an error, the
// response body is logged at debug level and an error carrying the HTTP status
// code is returned.
func (c *RepoImpl[T]) IndicesPutIndexTemplate(ctx context.Context, name string, body interface{}, o ...Option[opensearchapi.IndicesPutIndexTemplateRequest]) error {
	var buffer bytes.Buffer
	if err := json.NewEncoder(&buffer).Encode(body); err != nil {
		return fmt.Errorf("unable to encode index template: %w", err)
	}
	resp, err := c.client.IndicesPutIndexTemplate(ctx, name, &buffer, o...)
	if err != nil {
		return err
	}
	// Close the body so the HTTP transport can reuse the connection.
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if resp != nil && resp.IsError() {
		logger.WithContext(ctx).Debugf("error response: %s", resp.String())
		return fmt.Errorf("error status code: %d", resp.StatusCode)
	}
	return nil
}
// IndicesPutIndexTemplate creates or updates the index template called name from body.
//
// Like the other OpenClientImpl commands, it runs the registered before hooks
// (which may replace the context or edit the options) with
// CmdIndicesPutIndexTemplate, sends the request with the resulting context, and
// then runs the after hooks, which may overwrite the returned error through
// AfterContext.Err.
func (c *OpenClientImpl) IndicesPutIndexTemplate(ctx context.Context, name string, body io.Reader, o ...Option[opensearchapi.IndicesPutIndexTemplateRequest]) (*opensearchapi.Response, error) {
	options := make([]func(request *opensearchapi.IndicesPutIndexTemplateRequest), len(o))
	for i, v := range o {
		options[i] = v
	}
	for _, hook := range c.beforeHook {
		ctx = hook.Before(ctx, BeforeContext{cmd: CmdIndicesPutIndexTemplate, Options: &options})
	}
	//nolint:makezero
	options = append(options, IndicesPutIndexTemplate.WithContext(ctx))
	resp, err := c.client.API.Indices.PutIndexTemplate(name, body, options...)
	for _, hook := range c.afterHook {
		ctx = hook.After(ctx, AfterContext{cmd: CmdIndicesPutIndexTemplate, Options: &options, Resp: resp, Err: &err})
	}
	return resp, err
}
// indicesPutIndexTemplate exposes the request-option helpers of
// opensearchapi.IndicesPutIndexTemplate (for example WithContext).
type indicesPutIndexTemplate struct {
opensearchapi.IndicesPutIndexTemplate
}
// IndicesPutIndexTemplate is the entry point for index-template put request options.
var IndicesPutIndexTemplate = indicesPutIndexTemplate{}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"github.com/cisco-open/go-lanai/pkg/actuator/health"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"go.uber.org/fx"
)
// logger is the package logger, registered under the name "Search".
var logger = log.New("Search")
// Module wires OpenSearch support into a bootstrap application. It embeds the default
// configuration, provides the values built by BindOpenSearchProperties, NewConfig,
// NewClient and tracingProvider, and invokes registerHealth.
var Module = &bootstrap.Module{
Precedence: bootstrap.OpenSearchPrecedence,
Options: []fx.Option{
appconfig.FxEmbeddedDefaults(defaultConfigFS),
fx.Provide(BindOpenSearchProperties),
fx.Provide(NewConfig),
fx.Provide(NewClient),
fx.Provide(tracingProvider()),
fx.Invoke(registerHealth),
},
}
// Use registers Module with bootstrap.
func Use() {
bootstrap.Register(Module)
}
// regDI holds the fx dependencies of registerHealth. Both are optional, so either
// may be nil when the application does not provide it.
type regDI struct {
fx.In
HealthRegistrar health.Registrar `optional:"true"`
OpenClient OpenClient `optional:"true"`
}
// registerHealth registers an OpenSearch HealthIndicator, but only when both a
// health registrar and an OpenClient are available; otherwise it does nothing.
func registerHealth(di regDI) {
	if di.HealthRegistrar != nil && di.OpenClient != nil {
		di.HealthRegistrar.MustRegister(NewHealthIndicator(di.OpenClient))
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"context"
"fmt"
"github.com/opensearch-project/opensearch-go/opensearchapi"
)
// Ping pings the OpenSearch cluster. It returns the transport error if the request
// fails. If the response reports an error, it logs the response body at debug level
// and returns an error carrying the HTTP status code. Otherwise it returns nil.
func (c *RepoImpl[T]) Ping(
	ctx context.Context,
	o ...Option[opensearchapi.PingRequest],
) error {
	resp, err := c.client.Ping(ctx, o...)
	if err != nil {
		return err
	}
	// Close the body so the HTTP transport can reuse the connection.
	if resp != nil && resp.Body != nil {
		defer func() { _ = resp.Body.Close() }()
	}
	if resp != nil && resp.IsError() {
		logger.WithContext(ctx).Debugf("error response: %s", resp.String())
		return fmt.Errorf("error status code: %d", resp.StatusCode)
	}
	return nil
}
// Ping pings the OpenSearch cluster.
//
// Registered before hooks run first and may replace the context or edit the
// option list through BeforeContext.Options. The resulting context is attached to
// the request. After hooks then see the response and may overwrite the returned
// error through AfterContext.Err.
func (c *OpenClientImpl) Ping(ctx context.Context, o ...Option[opensearchapi.PingRequest]) (*opensearchapi.Response, error) {
	opts := make([]func(request *opensearchapi.PingRequest), len(o))
	for i := range o {
		opts[i] = o[i]
	}

	before := BeforeContext{cmd: CmdPing, Options: &opts}
	for _, h := range c.beforeHook {
		ctx = h.Before(ctx, before)
	}

	//nolint:makezero
	opts = append(opts, Ping.WithContext(ctx))
	resp, err := c.client.Ping(opts...)

	after := AfterContext{cmd: CmdPing, Options: &opts, Resp: resp, Err: &err}
	for _, h := range c.afterHook {
		ctx = h.After(ctx, after)
	}
	return resp, err
}
// pingExt exposes the request-option helpers of opensearchapi.Ping (for example
// WithContext) and can be extended with custom ones.
type pingExt struct {
opensearchapi.Ping
}
// Ping is the entry point for ping request options.
var Ping = pingExt{}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"embed"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/pkg/errors"
)
const (
// PropertiesPrefix is the configuration prefix that Properties are bound from.
PropertiesPrefix = "data.opensearch"
)
// defaultConfigFS embeds defaults-opensearch.yml, which Module registers as the
// embedded default configuration for this package.
//go:embed defaults-opensearch.yml
var defaultConfigFS embed.FS
// Properties is the OpenSearch configuration bound from PropertiesPrefix.
type Properties struct {
// Addresses presumably lists the OpenSearch node URLs to connect to — confirm in NewConfig.
Addresses []string `json:"addresses"`
Username string `json:"username"`
Password string `json:"password"`
TLS TLS `json:"tls"`
}
// TLS holds the TLS settings for the OpenSearch connection.
type TLS struct {
// Enable presumably toggles TLS for the client transport — confirm in NewConfig.
Enable bool `json:"enable"`
// Certs is the certificate source used to build the client's TLS configuration.
Certs certs.SourceProperties `json:"certs"`
}
// NewOpenSearchProperties returns an empty Properties value. No defaults are set
// in code; they are all expected to come from defaults-opensearch.yml.
func NewOpenSearchProperties() *Properties {
	props := new(Properties)
	return props
}
// BindOpenSearchProperties binds the application configuration under
// PropertiesPrefix into a new Properties value. It panics when binding fails,
// which aborts application startup.
func BindOpenSearchProperties(ctx *bootstrap.ApplicationContext) *Properties {
	p := NewOpenSearchProperties()
	err := ctx.Config().Bind(p, PropertiesPrefix)
	if err == nil {
		return p
	}
	panic(errors.Wrap(err, "failed to bind OpenSearchProperties"))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"context"
"github.com/opensearch-project/opensearch-go/opensearchapi"
"github.com/opensearch-project/opensearch-go/opensearchutil"
)
// NewRepo returns an OpenSearch repository for documents of model type T.
//
// The model argument is never read; it only lets the compiler infer T, e.g.
// NewRepo(&MyModel{}, client).
func NewRepo[T any](model *T, client OpenClient) Repo[T] {
	repo := &RepoImpl[T]{}
	repo.client = client
	return repo
}
// Repo is a typed OpenSearch repository for documents of type T.
type Repo[T any] interface {
// Search will search the cluster for data.
//
// The data will be unmarshalled and returned to the dest argument.
// The body argument should follow the Search request body [Format].
//
// [Format]: https://opensearch.org/docs/latest/opensearch/rest-api/search/#request-body
Search(ctx context.Context, dest *[]T, body interface{}, o ...Option[opensearchapi.SearchRequest]) (int, error)
// SearchTemplate allows to use the Mustache language to pre-render a search definition
//
// The data will be unmarshalled and returned to the dest argument.
// The body argument should follow the Search request body [Format].
//
// [Format]: https://opensearch.org/docs/latest/opensearch/rest-api/search/#request-body
SearchTemplate(ctx context.Context, dest *[]T, body interface{}, o ...Option[opensearchapi.SearchTemplateRequest]) (int, error)
// Index will create a new Document in the index that is defined.
//
// The index argument defines the index name that the document should be stored in.
Index(ctx context.Context, index string, document T, o ...Option[opensearchapi.IndexRequest]) error
// BulkIndexer will process bulk requests of a single action type.
//
// The index argument defines the index name that the bulk action will target.
// The action argument must be one of: ("index", "create", "delete", "update").
// The bulkItems argument is the array of struct items to be actioned.
//
// [Ref]: https://pkg.go.dev/github.com/opensearch-project/opensearch-go/opensearchutil#BulkIndexerItem
BulkIndexer(ctx context.Context, action BulkAction, bulkItems *[]T, o ...Option[opensearchutil.BulkIndexerConfig]) (opensearchutil.BulkIndexerStats, error)
// IndicesCreate will create a new index in the cluster.
//
// The index argument defines the index name to be created.
// The mapping argument should follow the Index Create request body [Format].
//
// [Format]: https://opensearch.org/docs/latest/opensearch/rest-api/index-apis/create-index/#request-body
IndicesCreate(ctx context.Context, index string, mapping interface{}, o ...Option[opensearchapi.IndicesCreateRequest]) error
// IndicesGet will return information about an index
//
// The index argument defines the index name we want to get
IndicesGet(ctx context.Context, index string, o ...Option[opensearchapi.IndicesGetRequest]) (*IndicesDetail, error)
// IndicesDelete will delete an index from the cluster.
//
// The index argument defines the index name to be deleted.
//
// [Format]: https://opensearch.org/docs/latest/opensearch/rest-api/index-apis/delete-index/
IndicesDelete(ctx context.Context, index []string, o ...Option[opensearchapi.IndicesDeleteRequest]) error
// IndicesPutAlias will create or update an alias
//
// The index argument defines the index that the alias should point to
// The name argument defines the name of the new alias
//
// [Format]: https://opensearch.org/docs/latest/opensearch/rest-api/alias/#request-body
IndicesPutAlias(ctx context.Context, index []string, name string, o ...Option[opensearchapi.IndicesPutAliasRequest]) error
// IndicesDeleteAlias deletes an alias
//
// The index argument defines the index that the alias points to
// The name argument defines the name of the alias we would like to delete
//
// [Format]: https://opensearch.org/docs/latest/opensearch/rest-api/alias/#request-body
IndicesDeleteAlias(ctx context.Context, index []string, name []string, o ...Option[opensearchapi.IndicesDeleteAliasRequest]) error
// IndicesPutIndexTemplate will create or update an index template
//
// The name argument defines the name of the template
// The body argument defines the specified template options to apply (refer to [Format])
//
// [Format]: https://opensearch.org/docs/latest/opensearch/index-templates/#index-template-options
IndicesPutIndexTemplate(ctx context.Context, name string, body interface{}, o ...Option[opensearchapi.IndicesPutIndexTemplateRequest]) error
// IndicesDeleteIndexTemplate deletes an index template
//
// The name argument defines the name of the template to delete
IndicesDeleteIndexTemplate(ctx context.Context, name string, o ...Option[opensearchapi.IndicesDeleteIndexTemplateRequest]) error
// Ping will ping the OpenSearch cluster. If no error is returned, then the ping was successful
Ping(ctx context.Context, o ...Option[opensearchapi.PingRequest]) error
// AddBeforeHook registers a hook that runs before each request.
AddBeforeHook(hook BeforeHook)
// AddAfterHook registers a hook that runs after each request.
AddAfterHook(hook AfterHook)
// RemoveBeforeHook unregisters a previously added BeforeHook (see Identifier).
RemoveBeforeHook(hook BeforeHook)
// RemoveAfterHook unregisters a previously added AfterHook (see Identifier).
RemoveAfterHook(hook AfterHook)
}
// RepoImpl is the default Repo implementation for documents of type T.
// It keeps no state of its own besides the OpenClient it delegates every call to,
// including hook registration.
type RepoImpl[T any] struct {
	client OpenClient
}

// AddBeforeHook registers hook with the underlying OpenClient.
func (c *RepoImpl[T]) AddBeforeHook(hook BeforeHook) {
	target := c.client
	target.AddBeforeHook(hook)
}

// AddAfterHook registers hook with the underlying OpenClient.
func (c *RepoImpl[T]) AddAfterHook(hook AfterHook) {
	target := c.client
	target.AddAfterHook(hook)
}

// RemoveBeforeHook unregisters hook from the underlying OpenClient.
func (c *RepoImpl[T]) RemoveBeforeHook(hook BeforeHook) {
	target := c.client
	target.RemoveBeforeHook(hook)
}

// RemoveAfterHook unregisters hook from the underlying OpenClient.
func (c *RepoImpl[T]) RemoveAfterHook(hook AfterHook) {
	target := c.client
	target.RemoveAfterHook(hook)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"bytes"
"encoding/json"
"github.com/opensearch-project/opensearch-go/opensearchapi"
"io"
)
// UnmarshalResponse decodes the JSON body of resp into a new T and leaves the body
// re-readable for later consumers (e.g. the code that issued the request, after hooks ran).
//
// The body is read completely and the original body is then closed, so the underlying
// connection is released. resp.Body is replaced with an in-memory reader over the same
// bytes; that replacement needs no closing.
//
// A nil resp.Body is treated as an empty body, which makes JSON decoding fail. resp itself
// must not be nil. The returned pointer is never nil. On a decoding error it points to a zero
// or partially populated T. If reading the body fails, the error is returned and resp.Body is
// left as the (now closed) original body.
func UnmarshalResponse[T any](resp *opensearchapi.Response) (*T, error) {
	var model T
	var respBody []byte
	if resp.Body != nil {
		var err error
		respBody, err = io.ReadAll(resp.Body)
		// The original body has been drained (or failed). It is replaced below, so nobody
		// else would ever close it: close it here. The close error is irrelevant now.
		_ = resp.Body.Close()
		if err != nil {
			return &model, err
		}
	}
	// restore resp.Body so it can be read again after this call
	resp.Body = io.NopCloser(bytes.NewReader(respBody))
	if err := json.Unmarshal(respBody, &model); err != nil {
		return &model, err
	}
	return &model, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/opensearch-project/opensearch-go/opensearchapi"
"io"
"net/http"
)
var (
// ErrIndexNotFound is returned (wrapped, match it with errors.Is) by RepoImpl.Search
// when OpenSearch answers a search with HTTP 404.
ErrIndexNotFound = errors.New("index not found")
)
// SearchResponse modeled after https://opensearch.org/docs/latest/opensearch/rest-api/search/#response-body
//
// Only the fields used by this package are decoded. T is the type each hit's _source is
// decoded into. Hits.Total.Value is the total number of matches reported by OpenSearch and
// may exceed len(Hits.Hits), which holds only the returned hits. A JSON null max_score
// leaves MaxScore at 0.
type SearchResponse[T any] struct {
Took int `json:"took"`
TimedOut bool `json:"timed_out"`
Shards struct {
Total int `json:"total"`
Successful int `json:"successful"`
Skipped int `json:"skipped"`
Failed int `json:"failed"`
} `json:"_shards"`
Hits struct {
MaxScore float64 `json:"max_score"`
Total struct {
Value int `json:"value"`
} `json:"total"`
Hits []struct {
Index string `json:"_index"`
ID string `json:"_id"`
Score float64 `json:"_score"`
Source T `json:"_source"`
} `json:"hits"`
} `json:"hits"`
}
// Search runs a search request and decodes the returned documents into dest.
//
// body is JSON-encoded and sent as the request body. o may carry further request options,
// such as the target index. The caller's option slice is never modified.
//
// On success, *dest is replaced with the _source of every returned hit. The returned count is
// the total hit count reported by OpenSearch (hits.total.value), which may be larger than
// len(*dest).
//
// An HTTP 404 yields an error wrapping ErrIndexNotFound. Other error statuses yield an error
// containing the status code. The response body is always closed before returning.
func (c *RepoImpl[T]) Search(ctx context.Context, dest *[]T, body interface{}, o ...Option[opensearchapi.SearchRequest]) (hits int, err error) {
	var buffer bytes.Buffer
	if err = json.NewEncoder(&buffer).Encode(body); err != nil {
		return 0, fmt.Errorf("unable to encode mapping: %w", err)
	}
	// Copy before appending, so we never write into spare capacity of the caller's slice.
	opts := make([]Option[opensearchapi.SearchRequest], 0, len(o)+1)
	opts = append(opts, o...)
	opts = append(opts, Search.WithBody(&buffer))

	resp, err := c.client.Search(ctx, opts...)
	if err != nil {
		return 0, err
	}
	if resp == nil {
		return 0, fmt.Errorf("empty search response")
	}
	// resp.Body is looked up at return time, because hooks (see UnmarshalResponse) may have
	// replaced it with an in-memory copy. Error responses are otherwise never drained/closed.
	defer func() {
		if resp.Body != nil {
			_ = resp.Body.Close()
		}
	}()

	if resp.IsError() {
		logger.WithContext(ctx).Errorf("error response: %s", resp.String())
		if resp.StatusCode == http.StatusNotFound {
			return 0, fmt.Errorf("%w", ErrIndexNotFound)
		}
		return 0, fmt.Errorf("error status code: %d", resp.StatusCode)
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, err
	}
	var searchResp SearchResponse[T]
	if err = json.Unmarshal(respBody, &searchResp); err != nil {
		return 0, err
	}
	retModel := make([]T, len(searchResp.Hits.Hits))
	for i, hit := range searchResp.Hits.Hits {
		retModel[i] = hit.Source
	}
	*dest = retModel
	return searchResp.Hits.Total.Value, nil
}
// Search executes a raw search request through the underlying OpenSearch API client.
//
// The before-hooks are sorted with order.OrderedFirstCompare and run in that order. Each may
// replace the context and receives a pointer to the option list, so it can modify it. The
// final context is attached to the request as an option. After the call, every after-hook
// runs with the response and a pointer to the error. The error returned is the one left
// behind by the hooks.
func (c *OpenClientImpl) Search(ctx context.Context, o ...Option[opensearchapi.SearchRequest]) (*opensearchapi.Response, error) {
	opts := make([]func(request *opensearchapi.SearchRequest), 0, len(o)+1)
	for _, opt := range o {
		opts = append(opts, opt)
	}

	order.SortStable(c.beforeHook, order.OrderedFirstCompare)
	for _, h := range c.beforeHook {
		ctx = h.Before(ctx, BeforeContext{cmd: CmdSearch, Options: &opts})
	}

	opts = append(opts, Search.WithContext(ctx))
	resp, err := c.client.API.Search(opts...)

	for _, h := range c.afterHook {
		ctx = h.After(ctx, AfterContext{cmd: CmdSearch, Options: &opts, Resp: resp, Err: &err})
	}
	return resp, err
}
// searchExt embeds opensearchapi.Search. Through it, the package-level Search value exposes
// the upstream request-option helpers (e.g. Search.WithBody, Search.WithContext). It can also
// be extended with package-specific helpers, for example:
//
//	func (s searchExt) WithSomething() func(request *opensearchapi.SearchRequest) {
//		return func(request *opensearchapi.SearchRequest) {
//		}
//	}
type searchExt struct {
	opensearchapi.Search
}

// Search provides the option helpers for search requests.
var Search searchExt
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/opensearch-project/opensearch-go/opensearchapi"
"io"
)
// SearchTemplate runs a search-template request and decodes the returned documents into dest.
//
// body is JSON-encoded and sent as the request body (refer to the OpenSearch search template
// API). o may carry further request options.
//
// On success, *dest is replaced with the _source of every returned hit. The returned count is
// the total hit count reported by OpenSearch (hits.total.value). Any error status yields an
// error containing the status code. The response body is always closed before returning.
func (c *RepoImpl[T]) SearchTemplate(ctx context.Context, dest *[]T, body interface{}, o ...Option[opensearchapi.SearchTemplateRequest]) (hits int, err error) {
	var buffer bytes.Buffer
	if err = json.NewEncoder(&buffer).Encode(body); err != nil {
		return 0, fmt.Errorf("unable to encode mapping: %w", err)
	}
	resp, err := c.client.SearchTemplate(ctx, &buffer, o...)
	if err != nil {
		return 0, err
	}
	if resp == nil {
		return 0, fmt.Errorf("empty search template response")
	}
	// resp.Body is looked up at return time, because hooks may have replaced it with an
	// in-memory copy. Error responses are otherwise never drained/closed.
	defer func() {
		if resp.Body != nil {
			_ = resp.Body.Close()
		}
	}()

	if resp.IsError() {
		logger.WithContext(ctx).Debugf("error response: %s", resp.String())
		return 0, fmt.Errorf("error status code: %d", resp.StatusCode)
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, err
	}
	var searchResp SearchResponse[T]
	if err = json.Unmarshal(respBody, &searchResp); err != nil {
		return 0, err
	}
	retModel := make([]T, len(searchResp.Hits.Hits))
	for i, hit := range searchResp.Hits.Hits {
		retModel[i] = hit.Source
	}
	*dest = retModel
	return searchResp.Hits.Total.Value, nil
}
// SearchTemplate executes a raw search-template request with the given body through the
// underlying OpenSearch API client.
//
// The before-hooks are sorted with order.OrderedFirstCompare and run in that order. Each may
// replace the context and receives a pointer to the option list. The final context is
// attached as an option. After the call, every after-hook runs with the response and a
// pointer to the error, and the (possibly hook-modified) error is returned.
func (c *OpenClientImpl) SearchTemplate(ctx context.Context, body io.Reader, o ...Option[opensearchapi.SearchTemplateRequest]) (*opensearchapi.Response, error) {
	opts := make([]func(request *opensearchapi.SearchTemplateRequest), 0, len(o)+1)
	for _, opt := range o {
		opts = append(opts, opt)
	}

	order.SortStable(c.beforeHook, order.OrderedFirstCompare)
	for _, h := range c.beforeHook {
		ctx = h.Before(ctx, BeforeContext{cmd: CmdSearchTemplate, Options: &opts})
	}

	opts = append(opts, SearchTemplate.WithContext(ctx))
	resp, err := c.client.API.SearchTemplate(body, opts...)

	for _, h := range c.afterHook {
		ctx = h.After(ctx, AfterContext{cmd: CmdSearchTemplate, Options: &opts, Resp: resp, Err: &err})
	}
	return resp, err
}
// searchTemplateExt embeds opensearchapi.SearchTemplate so the package-level SearchTemplate
// value exposes the upstream request-option helpers (e.g. SearchTemplate.WithContext) and can
// be extended with package-specific ones.
type searchTemplateExt struct {
	opensearchapi.SearchTemplate
}

// SearchTemplate provides the option helpers for search-template requests.
var SearchTemplate searchTemplateExt
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"github.com/cisco-open/go-lanai/pkg/opensearch"
"math/rand"
"strings"
"time"
)
// GenericAuditEvent is the test document that SetupPrepareOpenSearchData indexes into the
// "auditlog" index.
//
// The fields have no json tags. If documents are serialized with encoding/json (presumably
// what the bulk indexer does), the Go field names, underscores included, become the document
// field names. NOTE(review): they presumably mirror an existing audit index mapping; confirm
// before renaming any field.
type GenericAuditEvent struct {
Client_ID string
Description string
Details string
ID string
Keywords string
Orig_User string
Owner_Tenant_ID string
Parent_Span_ID string
Provider_ID string
Security string
Service string
Severity string
Span_ID string
SubType string
Tenant_ID string
Tenant_Name string
Time time.Time
Time_Bucket int
Trace string
Trace_ID string
Type string
User_ID string
Username string
}
// SetupPrepareOpenSearchData recreates the "auditlog" test data.
//
// Any existing "auditlog" index is deleted (failures are ignored). Then 10 generated
// GenericAuditEvents, timestamped between startDate and endDate, are bulk-indexed into it
// with a single worker and refresh enabled. The given context is returned unchanged together
// with any bulk-indexing error.
func SetupPrepareOpenSearchData(
	ctx context.Context,
	repo opensearch.Repo[GenericAuditEvent],
	startDate time.Time,
	endDate time.Time,
) (context.Context, error) {
	// Deleting fails when the index doesn't exist yet, which is fine here.
	//nolint:errcheck
	_ = repo.IndicesDelete(ctx, []string{"auditlog"})

	var events []GenericAuditEvent
	CreateData(10, startDate, endDate, &events)

	if _, err := repo.BulkIndexer(
		ctx,
		"index",
		&events,
		opensearch.BulkIndexer.WithIndex("auditlog"),
		opensearch.BulkIndexer.WithWorkers(1),
		opensearch.BulkIndexer.WithRefresh(true),
	); err != nil {
		return ctx, err
	}
	return ctx, nil
}
// CreateData fills *dest with numOfDocuments generated GenericAuditEvents.
//
// Event contents come from PopulateSourceWithDeterministicData, so they are reproducible.
// Timestamps are evenly spaced rather than random. [startT, endT] is split into
// numOfDocuments equal steps, and the i-th event (0-based) is stamped startT + (i+1)*step.
// The last event therefore lands on endT, up to integer-division rounding.
//
// When numOfDocuments <= 0, *dest is set to an empty slice. Without this guard the step
// computation would divide by zero, and make would panic on a negative length.
func CreateData(numOfDocuments int, startT time.Time, endT time.Time, dest *[]GenericAuditEvent) {
	if numOfDocuments <= 0 {
		*dest = make([]GenericAuditEvent, 0)
		return
	}
	timeDelta := endT.Sub(startT) / time.Duration(numOfDocuments)
	genericEvents := make([]GenericAuditEvent, numOfDocuments)
	currentTime := startT
	for i := range genericEvents {
		PopulateSourceWithDeterministicData(&genericEvents[i])
		currentTime = currentTime.Add(timeDelta)
		genericEvents[i].Time = currentTime
	}
	*dest = genericEvents
}
// PopulateSourceWithDeterministicData overwrites selected fields of source with values drawn
// from the package's fixed-seed source src, so the same sequence of events is produced on
// every run.
//
// Type and SubType are picked from small fixed sets. The ID-like and user fields get
// 5-letter strings from RandStringBytesMaskImprSrcSB. Other fields are left untouched.
func PopulateSourceWithDeterministicData(source *GenericAuditEvent) {
	types := []string{"GP", "DEVICE", "DP"}
	subTypes := []string{"W", "SCHEDULE_TASK", "SYNCHRONIZED"}

	// The order of the random draws below is significant: changing it changes the data.
	source.Type = types[int(src.Int63())%len(types)]
	source.SubType = subTypes[int(src.Int63())%len(subTypes)]

	randomFields := []*string{
		&source.Trace_ID,
		&source.Span_ID,
		&source.Parent_Span_ID,
		&source.Client_ID,
		&source.Tenant_ID,
		&source.Provider_ID,
		&source.Owner_Tenant_ID,
		&source.User_ID,
		&source.Orig_User,
		&source.Username,
		&source.Keywords,
	}
	for _, field := range randomFields {
		*field = RandStringBytesMaskImprSrcSB(5)
	}
}
// Alphabet used by RandStringBytesMaskImprSrcSB. The full a-zA-Z alphabet (kept below for
// reference) was narrowed to 10 letters to limit the number of distinct generated values.
// const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const letterBytes = "abcdefghij" // limiting the combination of characters
// Bit-masking parameters: each 63-bit random value is consumed letterIdxBits bits at a time.
// NOTE(review): 4 bits would cover a 10-letter alphabet, but changing letterIdxBits would
// change the generated (deterministic) data.
const (
letterIdxBits = 6 // 6 bits to represent a letter index
letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits
)
// We don't want random data that changes from run to run
// src uses a fixed seed so generated test data is identical across runs. A math/rand
// Source created by NewSource is not safe for concurrent use by multiple goroutines.
var src = rand.NewSource(4242)
// RandStringBytesMaskImprSrcSB from https://stackoverflow.com/a/31832326
//
// It returns an n-character string over letterBytes, drawn from src. Each 63-bit value from
// src is consumed letterIdxBits bits at a time. A chunk whose value is not a valid index into
// letterBytes is skipped, which keeps the selection unbiased. One value is drawn from src even
// when n == 0. A negative n panics (strings.Builder.Grow).
func RandStringBytesMaskImprSrcSB(n int) string {
	var out strings.Builder
	out.Grow(n)

	bits, chunksLeft := src.Int63(), letterIdxMax
	for written := 0; written < n; {
		if chunksLeft == 0 {
			// current value exhausted: draw a fresh one
			bits, chunksLeft = src.Int63(), letterIdxMax
		}
		idx := int(bits & letterIdxMask)
		if idx < len(letterBytes) {
			out.WriteByte(letterBytes[idx])
			written++
		}
		bits >>= letterIdxBits
		chunksLeft--
	}
	return out.String()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearch
import (
"context"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"go.uber.org/fx"
)
// Tracer is an OpenSearch BeforeHook/AfterHook pair that wraps each command in an
// OpenTracing client span. With a nil tracer, both hooks return the context unchanged.
type Tracer struct {
	tracer opentracing.Tracer
}

// Before starts a span named "opensearch <command>" with RPC-client kind and a "command" tag,
// and returns the context carrying it. It uses DescendantOrNoSpan; NOTE(review): presumably a
// span is only created when ctx already carries a parent span — confirm in pkg/tracing.
func (t *Tracer) Before(ctx context.Context, before BeforeContext) context.Context {
	if t.tracer == nil {
		return ctx
	}
	cmd := before.CommandType()
	return tracing.WithTracer(t.tracer).
		WithOpName("opensearch " + cmd.String()).
		WithOptions(
			tracing.SpanKind(ext.SpanKindRPCClientEnum),
			tracing.SpanTag("command", cmd),
		).
		DescendantOrNoSpan(ctx)
}
// After finishes the span started by Before for the same command.
//
// The span is tagged with one of the following:
//   - "status code", when OpenSearch answered with an error status;
//   - "error" (the error value), when the call itself failed;
//   - "hits" and "maxscore", for successful search commands. This decodes the response via
//     UnmarshalResponse, which leaves a re-readable copy of the body in Resp.Body.
//
// A nil Resp or a nil Err pointer is tolerated. The returned context comes from
// FinishAndRewind. NOTE(review): that is presumably the context that was active before
// Before started the span — confirm in pkg/tracing.
func (t *Tracer) After(ctx context.Context, afterContext AfterContext) context.Context {
	if t.tracer == nil {
		return ctx
	}
	resp := afterContext.Resp
	var cmdErr error
	if afterContext.Err != nil {
		cmdErr = *afterContext.Err
	}

	op := tracing.WithTracer(t.tracer)
	switch {
	case resp != nil && resp.IsError():
		op = op.WithOptions(
			tracing.SpanTag("status code", resp.StatusCode),
		)
	case cmdErr != nil:
		// tag the error value itself, not the *error pointer carried by AfterContext
		op = op.WithOptions(
			tracing.SpanTag("error", cmdErr),
		)
	case resp != nil && afterContext.CommandType() == CmdSearch:
		searchResp, err := UnmarshalResponse[SearchResponse[any]](resp)
		if err != nil {
			logger.WithContext(ctx).Errorf("unable to unmarshal error: %v", err)
		} else {
			op = op.WithOptions(
				tracing.SpanTag("hits", searchResp.Hits.Total.Value),
				tracing.SpanTag("maxscore", searchResp.Hits.MaxScore),
			)
		}
	}
	return op.FinishAndRewind(ctx)
}
// TracerHook creates a Tracer backed by tracer. A nil tracer yields a hook whose Before and
// After are no-ops.
func TracerHook(tracer opentracing.Tracer) *Tracer {
	return &Tracer{tracer: tracer}
}

// tracingDI collects the optional OpenTracing tracer from the fx container.
type tracingDI struct {
	fx.In
	Tracer opentracing.Tracer `optional:"true"`
}

// tracingProvider provides one Tracer hook, under FxGroup, as both a BeforeHook and an
// AfterHook.
func tracingProvider() fx.Annotated {
	newHooks := func(di tracingDI) (BeforeHook, AfterHook) {
		hook := TracerHook(di.Tracer)
		return hook, hook
	}
	return fx.Annotated{
		Group:  FxGroup,
		Target: newHooks,
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package monitor
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/google/uuid"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/process"
"os"
"runtime"
"runtime/pprof"
"sync"
"time"
)
var (
// SamplingRate is the interval between two samples (the collector's ticker period).
SamplingRate = time.Second
// SampleMaxSize is passed to DataStorage.AppendAll as the per-group cap on stored entries
// (86400 samples is one day at the default 1s SamplingRate).
SampleMaxSize int64 = 86400
// SampleProfiles lists the pprof profiles whose counts are sampled.
// NOTE(review): collect looks up all five names; removing one may make collect dereference
// a nil *pprof.Profile.
SampleProfiles = utils.NewStringSet("block", "goroutine", "heap", "mutex", "threadcreate")
)
// dataCollector periodically samples process/runtime statistics, saves them to a DataStorage
// and broadcasts them as Feed values to subscribers.
type dataCollector struct {
mtx sync.RWMutex
storage DataStorage
process *process.Process
// Previous sample values used to compute deltas. They are only touched by collect on
// the collector goroutine, and so are not guarded by mtx.
prevSysTime float64
prevUserTime float64
prevNumGC uint32
// Mutex protected fields
ticker *time.Ticker
canceller context.CancelFunc
subscribers map[string]chan Feed
}
// NewDataCollector creates a collector for the current OS process that saves samples to
// storage. Sampling only begins once Start is called. It panics if the current process
// cannot be inspected (process.NewProcess fails).
func NewDataCollector(storage DataStorage) *dataCollector {
	pid := int32(os.Getpid())
	proc, err := process.NewProcess(pid)
	if err != nil {
		panic(err)
	}
	collector := &dataCollector{storage: storage, process: proc}
	return collector
}
// Start begins sampling every SamplingRate on a background goroutine and enables Subscribe.
// The goroutine ends when ctx is cancelled or Stop is called. Calling Start on a collector
// that is already running does nothing.
func (c *dataCollector) Start(ctx context.Context) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	if c.ticker != nil {
		return
	}
	loopCtx, cancel := context.WithCancel(ctx)
	c.canceller = cancel
	c.ticker = time.NewTicker(SamplingRate)
	c.subscribers = make(map[string]chan Feed)
	loop := c.collectFunc(loopCtx, c.ticker)
	go loop()
}
// Stop halts sampling. It cancels the collector goroutine's context, closes and forgets every
// subscriber channel (subscribers see their channel closed), and stops the ticker. Stopping a
// collector that is not running does nothing.
func (c *dataCollector) Stop() {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	if c.ticker == nil {
		return
	}
	if cancel := c.canceller; cancel != nil {
		cancel()
		c.canceller = nil
	}
	for _, feeds := range c.subscribers {
		close(feeds)
	}
	c.subscribers = nil
	c.ticker.Stop()
	c.ticker = nil
}
// Subscribe registers a new unbuffered feed channel and returns it together with the id to
// pass to Unsubscribe. It fails if the collector has not been started.
func (c *dataCollector) Subscribe() (<-chan Feed, string, error) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	if c.ticker == nil {
		return nil, "", fmt.Errorf("cannot subscribe before data collector started")
	}
	feeds := make(chan Feed)
	id := uuid.New().String()
	c.subscribers[id] = feeds
	return feeds, id, nil
}
// Unsubscribe closes and removes the feed channel registered under id. Unknown ids, and calls
// made after Stop, are ignored.
func (c *dataCollector) Unsubscribe(id string) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	// lookups in a nil map are safe and simply report "not found"
	feeds, found := c.subscribers[id]
	if !found {
		return
	}
	delete(c.subscribers, id)
	close(feeds)
}
// collectFunc returns the collector loop. The loop takes a sample on every tick until ctx is
// done, then calls Stop, which does nothing if the collector was already stopped.
func (c *dataCollector) collectFunc(ctx context.Context, ticker *time.Ticker) func() {
	return func() {
		done := ctx.Done()
		running := true
		for running {
			select {
			case now := <-ticker.C:
				c.collect(ctx, now)
			case <-done:
				running = false
			}
		}
		c.Stop()
	}
}
// collect takes one sample at time now: process CPU times, Go memory/GC statistics and the
// counts of the pprof profiles in SampleProfiles. The sample is saved to c.storage, when one
// is configured, and broadcast to all subscribers as a Feed.
//
// It only runs on the collector goroutine, which is why the prev* fields are accessed
// without locking.
//
// Broadcasting is non-blocking: a subscriber that is not waiting on its channel at that moment
// misses this feed. A blocking send would hold the read lock while waiting on a subscriber
// that may have stopped reading (e.g. a DataFeed handler that is exiting). Unsubscribe and
// Stop need the write lock, so both sides would deadlock.
func (c *dataCollector) collect(ctx context.Context, now time.Time) {
	// collect facts; the timestamp is Unix time in milliseconds (second precision)
	timestamp := uint64(now.Unix()) * 1000
	cpuTimes, e := c.process.Times()
	var ms runtime.MemStats
	runtime.ReadMemStats(&ms)

	// pprof data
	profiles := make(map[string]*pprof.Profile)
	for _, p := range pprof.Profiles() {
		if SampleProfiles.Has(p.Name()) {
			profiles[p.Name()] = p
		}
	}
	// SampleProfiles is exported and may be changed; a missing profile counts as 0
	// instead of dereferencing a nil *pprof.Profile.
	profileCount := func(name string) int {
		if p := profiles[name]; p != nil {
			return p.Count()
		}
		return 0
	}
	gPprof := &PprofPair{
		Ts:           timestamp,
		Block:        profileCount("block"),
		Goroutine:    profileCount("goroutine"),
		Heap:         profileCount("heap"),
		Mutex:        profileCount("mutex"),
		Threadcreate: profileCount("threadcreate"),
	}

	// CPU usage data: delta since the previous sample, or zero values if unavailable
	if e != nil || cpuTimes == nil {
		cpuTimes = &cpu.TimesStat{}
	}
	gCpu := &CPUPair{
		Ts:   timestamp,
		User: cpuTimes.User - c.prevUserTime,
		Sys:  cpuTimes.System - c.prevSysTime,
	}
	c.prevUserTime = cpuTimes.User
	c.prevSysTime = cpuTimes.System

	// memory data
	gMemAlloc := &SimplePair{
		Ts:    timestamp,
		Value: ms.Alloc,
	}

	// GC data: report a pause on the first sample, and afterwards only when a new GC cycle
	// completed. PauseNs is a circular buffer; the latest pause is at (NumGC+255)%256.
	var gGCPause *SimplePair
	var gcPause uint64
	if c.prevNumGC == 0 || c.prevNumGC != ms.NumGC {
		gcPause = ms.PauseNs[(ms.NumGC+255)%256]
		gGCPause = &SimplePair{
			Ts:    timestamp,
			Value: gcPause,
		}
		c.prevNumGC = ms.NumGC
	}

	// Create data. A nil *SimplePair stored in an interface{} is not == nil, so an absent
	// GC pause must be stored as an untyped nil for the storage to recognize it.
	data := map[DataGroup]interface{}{
		GroupPprof:          gPprof,
		GroupGCPauses:       nil,
		GroupCPUUsage:       gCpu,
		GroupBytesAllocated: gMemAlloc,
	}
	if gGCPause != nil {
		data[GroupGCPauses] = gGCPause
	}

	// Create feed
	feed := Feed{
		Ts:             timestamp,
		BytesAllocated: gMemAlloc.Value,
		GcPause:        gcPause,
		CPUUser:        gCpu.User,
		CPUSys:         gCpu.Sys,
		Block:          gPprof.Block,
		Goroutine:      gPprof.Goroutine,
		Heap:           gPprof.Heap,
		Mutex:          gPprof.Mutex,
		Threadcreate:   gPprof.Threadcreate,
	}

	// Save (storage may be nil when no backend is available) and broadcast
	if c.storage != nil {
		if e := c.storage.AppendAll(ctx, data, SampleMaxSize); e != nil {
			logger.Debugf("Failed to save profiling data: %v", e)
		}
	}

	c.mtx.RLock()
	defer c.mtx.RUnlock()
	for _, ch := range c.subscribers {
		select {
		case ch <- feed:
		default:
			// subscriber not ready; drop this feed rather than block while holding the lock
		}
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package monitor
import (
"context"
"encoding/json"
"strings"
)
// Names of the data groups the collector stores. The Redis storage also uses them as the last
// component of its keys.
const (
GroupBytesAllocated DataGroup = "BytesAllocated"
GroupGCPauses DataGroup = "GcPauses"
GroupCPUUsage DataGroup = "CPUUsage"
GroupPprof DataGroup = "Pprof"
)
// DataGroup identifies one series of collected samples (e.g. GC pauses or CPU usage).
type DataGroup string
// DataStorage persists collected samples per DataGroup and reads them back for the charts.
type DataStorage interface {
// Read returns the stored raw entries of each requested group.
// NOTE(review): entries are presumably JSON documents, since RawEntries.MarshalJSON
// embeds them verbatim.
Read(ctx context.Context, groups ...DataGroup) (map[DataGroup]RawEntries, error)
// Append saves a single entry for group. cap bounds how many entries are kept (the Redis
// implementation keeps the newest cap entries and stores nothing for a nil entry).
Append(ctx context.Context, group DataGroup, entry interface{}, cap int64) error
// AppendAll saves one entry per DataGroup in data, with the same cap semantics as Append.
AppendAll(ctx context.Context, data map[DataGroup]interface{}, cap int64) error
}
// Feed is one sample broadcast to live subscribers (e.g. the chart data-feed websocket).
// Ts is a Unix timestamp in milliseconds, with second precision, as produced by
// dataCollector.collect. CPUUser/CPUSys are CPU-time deltas since the previous sample, in the
// units reported by gopsutil. GcPause is the latest GC pause in nanoseconds, or 0 when no new
// GC cycle was observed. Block..Threadcreate are pprof profile counts.
type Feed struct {
Ts uint64
BytesAllocated uint64
GcPause uint64
CPUUser float64
CPUSys float64
Block int
Goroutine int
Heap int
Mutex int
Threadcreate int
}
type RawEntries []string
func (v RawEntries) MarshalJSON() (data []byte, err error) {
if v == nil {
return []byte("null"), nil
}
return []byte("[" + strings.Join(v, ",") + "]"), nil
}
type SimplePair struct {
Ts uint64 `json:"Ts"`
Value uint64 `json:"Value"`
}
func (p *SimplePair) MarshalBinary() (data []byte, err error) {
return json.Marshal(p)
}
type CPUPair struct {
Ts uint64 `json:"Ts"`
User float64 `json:"User"`
Sys float64 `json:"Sys"`
}
func (p *CPUPair) MarshalBinary() (data []byte, err error) {
return json.Marshal(p)
}
type PprofPair struct {
Ts uint64 `json:"Ts"`
Block int `json:"Block"`
Goroutine int `json:"Goroutine"`
Heap int `json:"Heap"`
Mutex int `json:"Mutex"`
Threadcreate int `json:"Threadcreate"`
}
func (p *PprofPair) MarshalBinary() (data []byte, err error) {
return json.Marshal(p)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package monitor
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/profiler"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/assets"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
"time"
)
// HttpChartPrefix is the path segment (under profiler.RouteGroup) of all chart endpoints.
const (
HttpChartPrefix = "charts"
)
var (
// DataGroups lists the collected data groups returned by the "data" endpoint.
DataGroups = []DataGroup{GroupBytesAllocated, GroupGCPauses, GroupCPUUsage, GroupPprof}
// PongTimeout is how long the data feed waits for a pong message before it closes the
// websocket connection.
PongTimeout = time.Minute
// errWSWriterNotAvailable is returned by wsWriteJson when no message writer can be obtained.
errWSWriterNotAvailable = errors.New("WebSocket writer not available")
)
// ChartsForwardRequest binds the "path" URI parameter.
// NOTE(review): not referenced by the handlers in this file.
type ChartsForwardRequest struct {
Path string `uri:"path"`
}
// ChartController serves the profiler chart UI, its historical data and a live websocket feed.
type ChartController struct {
storage DataStorage
collector *dataCollector
upgrader *websocket.Upgrader
}
// NewChartController creates a controller that reads history from storage, streams live
// samples from collector, and upgrades feed requests with 1 KiB websocket read/write buffers.
func NewChartController(storage DataStorage, collector *dataCollector) *ChartController {
	upgrader := &websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
	}
	return &ChartController{storage: storage, collector: collector, upgrader: upgrader}
}
// Mappings registers, under "<RouteGroup>/charts":
//   - static assets at "/static";
//   - the chart UI at "/";
//   - historical data at "/data";
//   - the live websocket feed at "/data-feed".
func (c *ChartController) Mappings() []web.Mapping {
	base := profiler.RouteGroup
	staticPath := base + "/" + HttpChartPrefix + "/static"
	return []web.Mapping{
		assets.New(staticPath, "static/"),
		web.NewSimpleMapping("chart_ui", base, HttpChartPrefix+"/", http.MethodGet, nil, c.ChartUI),
		web.NewSimpleGinMapping("chart_data", base, HttpChartPrefix+"/data", http.MethodGet, nil, c.Data),
		web.NewSimpleGinMapping("chart_feed", base, HttpChartPrefix+"/data-feed", http.MethodGet, nil, c.DataFeed),
	}
}
// ChartUI serves the embedded static/index.html. It responds with 404 if the file cannot be
// opened or stat'ed. The opened file is always closed once it has been served.
func (c *ChartController) ChartUI(w http.ResponseWriter, r *http.Request) {
	file, err := http.FS(Content).Open("static/index.html")
	if err != nil {
		w.WriteHeader(http.StatusNotFound)
		return
	}
	defer func() { _ = file.Close() }()

	fileInfo, err := file.Stat()
	if err != nil {
		w.WriteHeader(http.StatusNotFound)
		return
	}
	http.ServeContent(w, r, fileInfo.Name(), fileInfo.ModTime(), file)
}
// Data writes the stored samples of all DataGroups as a JSONP payload: `<callback>(<json>)`.
//
// The "callback" query parameter is echoed verbatim into the response. It is therefore limited
// to an optionally dotted JavaScript identifier (ASCII letters, digits, '_' and '$'), and
// anything else is rejected with 400 to prevent script injection. An empty callback is still
// accepted and yields "(<json>)". If no DataStorage is configured, or writing fails, a 500
// JSON error is produced.
func (c *ChartController) Data(gc *gin.Context) {
	callback := gc.Query("callback")
	if !isSafeJSONPCallback(callback) {
		gc.AbortWithStatusJSON(http.StatusBadRequest, map[string]interface{}{
			"error": "invalid callback",
		})
		return
	}
	if c.storage == nil {
		c.handleError(gc, errors.New("chart data storage is not available"))
		return
	}
	gc.Header("Content-Type", "application/json")
	data, e := c.storage.Read(gc.Request.Context(), DataGroups...)
	if e != nil {
		c.handleError(gc, e)
		return
	}
	if _, e := fmt.Fprintf(gc.Writer, "%v(", callback); e != nil {
		c.handleError(gc, e)
		return
	}
	encoder := json.NewEncoder(gc.Writer)
	if e := encoder.Encode(data); e != nil {
		c.handleError(gc, e)
		return
	}
	if _, e := fmt.Fprint(gc.Writer, ")"); e != nil {
		c.handleError(gc, e)
		return
	}
}

// isSafeJSONPCallback reports whether name can be echoed back as a JSONP callback without
// enabling script injection. Accepted values are the empty string, or dot-separated segments
// that each start with an ASCII letter, '_' or '$' and continue with those or digits.
func isSafeJSONPCallback(name string) bool {
	segmentStart := true
	for i := 0; i < len(name); i++ {
		ch := name[i]
		switch {
		case ch == '.':
			if segmentStart {
				return false // empty segment
			}
			segmentStart = true
		case ch == '_' || ch == '$' || ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z'):
			segmentStart = false
		case '0' <= ch && ch <= '9':
			if segmentStart {
				return false // a segment must not start with a digit
			}
		default:
			return false
		}
	}
	// reject a trailing '.', accept the empty string
	return len(name) == 0 || !segmentStart
}
// DataFeed upgrades the request to a websocket and streams every Feed the collector
// broadcasts, JSON-encoded, one message per feed.
//
// Incoming messages are read and discarded on a separate goroutine, which also processes
// control frames such as pongs. A ping is sent every 10 feeds. If the latest pong is more than
// PongTimeout older than the previous ping, the connection is closed.
//
// The stream also ends when any of these happens:
//   - the request context is done;
//   - the collector stops (closing the feed channel);
//   - no message writer can be obtained;
//   - a ping cannot be written.
func (c *ChartController) DataFeed(gc *gin.Context) {
	ctx := gc.Request.Context()

	// Subscribe data collector. Unsubscribe ignores unknown ids (including the empty id
	// returned on failure), so it can be deferred before the error check.
	ch, id, e := c.collector.Subscribe()
	defer c.collector.Unsubscribe(id)
	if e != nil {
		c.handleError(gc, e)
		return
	}

	// Upgrade to websocket connection. On failure Upgrade has already replied to the client
	// with an HTTP error, so no further response must be written.
	ws, e := c.upgrader.Upgrade(gc.Writer, gc.Request, nil)
	if e != nil {
		logger.WithContext(ctx).Debugf("websocket upgrade failed: %v", e)
		return
	}
	defer func() {
		_ = ws.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "connection closed by server"), time.Now().Add(time.Second))
		_ = ws.Close()
	}()

	// The pong handler runs on the reader goroutine. Pong times are handed to this loop
	// through a 1-slot channel (latest value wins) instead of a shared variable, which would
	// be a data race. It is installed before the reader goroutine starts.
	pongs := make(chan time.Time, 1)
	ws.SetPongHandler(func(string) error {
		// Drop a stale, unconsumed pong time. The handler is the only sender, so the
		// send below never blocks.
		select {
		case <-pongs:
		default:
		}
		pongs <- time.Now()
		return nil
	})

	// read and discard all messages
	go c.wsReadSink(ctx)(ws)

	// write feed when received from collector, and send PingMessage every 10 feeds
	var lastPing, lastPong time.Time
LOOP:
	for i := uint(0); ; i++ {
		select {
		case feed, ok := <-ch:
			if !ok {
				// the collector stopped and closed our channel
				break LOOP
			}
			if e := c.wsWriteJson(ws, &feed); errors.Is(e, errWSWriterNotAvailable) {
				break LOOP
			}
		case <-ctx.Done():
			break LOOP
		}
		if i%10 != 0 {
			continue
		}
		i = 0

		// pick up the latest pong time, if one arrived since the last check
		select {
		case t := <-pongs:
			lastPong = t
		default:
		}
		// If we didn't receive Pong after PongTimeout, we quit loop and close connection
		if lastPing.Sub(lastPong) > PongTimeout {
			logger.WithContext(ctx).Debugf("No 'pong' message received after %v, closing connection...", PongTimeout)
			break LOOP
		}
		// Ping
		lastPing = time.Now()
		if e := ws.WriteControl(websocket.PingMessage, nil, lastPing.Add(time.Second)); e != nil {
			break LOOP
		}
	}
}
// handleError aborts the request with HTTP 500 and a JSON body of the form {"error": "<message>"}.
func (c *ChartController) handleError(gc *gin.Context, e error) {
	body := map[string]interface{}{"error": e.Error()}
	gc.AbortWithStatusJSON(http.StatusInternalServerError, body)
}
// wsWriteJson sends v as a single JSON text message on ws.
//
// It returns errWSWriterNotAvailable when no message writer can be obtained (typically the
// connection is closed or broken). Otherwise it returns the encoding error, if any, or else
// the error from closing the writer. The writer is closed right after encoding: gorilla/
// websocket only completes and flushes a message on Close (or when the next writer is
// requested), so without it each feed would be delivered one message late.
func (c *ChartController) wsWriteJson(ws *websocket.Conn, v interface{}) error {
	w, e := ws.NextWriter(websocket.TextMessage)
	if e != nil {
		return errWSWriterNotAvailable
	}
	encodeErr := json.NewEncoder(w).Encode(v)
	closeErr := w.Close()
	if encodeErr != nil {
		return encodeErr
	}
	return closeErr
}
// wsReadSink returns a reader loop that pulls, and thereby discards, every incoming message
// on ws. Reading is what lets gorilla/websocket process control frames such as pongs. The
// loop stops once ctx is done (checked before each read) or when reading fails, e.g. because
// the connection was closed.
func (c *ChartController) wsReadSink(ctx context.Context) func(ws *websocket.Conn) {
	return func(ws *websocket.Conn) {
		for {
			select {
			case <-ctx.Done():
				return
			default:
			}
			if _, _, err := ws.NextReader(); err != nil {
				return
			}
		}
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package monitor
import (
"context"
"embed"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/web"
"go.uber.org/fx"
)
// logger is the package logger for the profiler charts.
var logger = log.New("PProf.Charts")
// Content holds the embedded chart UI assets (the static/ directory). initialize registers it
// with the web registrar, and ChartController.ChartUI serves static/index.html from it.
//go:embed static/*
var Content embed.FS
// Module provides the chart DataStorage and data collector and runs initialize.
var Module = &bootstrap.Module{
Precedence: bootstrap.DebugPrecedence,
Options: []fx.Option{
fx.Provide(provideDataStorage, NewDataCollector),
fx.Invoke(initialize),
},
}
// Use registers Module with bootstrap; call it from main() to enable the profiler charts.
func Use() {
bootstrap.Register(Module)
}
// storageDI holds the dependencies used to build the chart DataStorage.
type storageDI struct {
	fx.In
	AppCtx       *bootstrap.ApplicationContext
	RedisFactory redis.ClientFactory `optional:"true"`
}

// provideDataStorage returns a Redis-backed DataStorage when a redis.ClientFactory is
// available, and a nil DataStorage otherwise.
// NOTE(review): consumers of the storage must cope with nil until a fallback exists.
func provideDataStorage(di storageDI) DataStorage {
	if di.RedisFactory == nil {
		return nil // TODO: in-memory storage as fallback
	}
	return NewRedisDataStorage(di.AppCtx, di.RedisFactory)
}
// initDI holds the dependencies of initialize.
type initDI struct {
	fx.In
	LC        fx.Lifecycle
	AppCtx    *bootstrap.ApplicationContext
	Registrar *web.Registrar `optional:"true"`
	Collector *dataCollector
}

// initialize does three things:
//   - when a web registrar is present, registers the embedded UI assets and the chart
//     controller with it;
//   - starts the collector immediately, bound to the application context;
//   - stops the collector when the fx application stops.
func initialize(di initDI) {
	collector := di.Collector
	if registrar := di.Registrar; registrar != nil {
		registrar.MustRegister(Content)
		registrar.MustRegister(NewChartController(collector.storage, collector))
	}
	collector.Start(di.AppCtx)
	di.LC.Append(fx.Hook{
		OnStop: func(context.Context) error {
			collector.Stop()
			return nil
		},
	})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package monitor
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/utils"
goRedis "github.com/go-redis/redis/v8"
"time"
)
const (
// redisDB is the Redis database index the chart storage client selects.
redisDB = 6
// prefixData is the first component of every data key (see groupKey).
prefixData = "D"
// ttl is re-applied on every append, so stored samples disappear ttl after the last write
// (e.g. shortly after the collector stops).
ttl = 5 * time.Second
)
// redisDataStorage is a DataStorage that keeps each DataGroup in a Redis list.
type redisDataStorage struct {
client redis.Client
// identifier is a random string generated per storage instance, which scopes its keys.
identifier string
}
// NewRedisDataStorage creates a storage backed by a new client from cf, which uses database
// redisDB. The keys are scoped by a random 10-character identifier. It panics if the client
// cannot be created.
func NewRedisDataStorage(ctx context.Context, cf redis.ClientFactory) *redisDataStorage {
	selectDB := func(opt *redis.ClientOption) {
		opt.DbIndex = redisDB
	}
	client, err := cf.New(ctx, selectDB)
	if err != nil {
		panic(err)
	}
	return &redisDataStorage{client: client, identifier: utils.RandomString(10)}
}
// Read returns the stored entries of each requested group, oldest first.
//
// All LRANGE commands are sent in a single pipeline. A pipeline failure, or a failure of any
// individual command, aborts the read. Command errors are wrapped with %w, so callers can
// inspect the underlying Redis error with errors.Is/As; the message text is unchanged.
func (s *redisDataStorage) Read(ctx context.Context, groups ...DataGroup) (map[DataGroup]RawEntries, error) {
	pipeliner := s.client.Pipeline()
	defer func() { _ = pipeliner.Close() }()
	for _, group := range groups {
		pipeliner.LRange(ctx, s.groupKey(group), 0, -1)
	}
	cmds, e := pipeliner.Exec(ctx)
	if e != nil {
		return nil, e
	}

	data := make(map[DataGroup]RawEntries, len(groups))
	for i, cmd := range cmds {
		if err := cmd.Err(); err != nil {
			return nil, fmt.Errorf("%s failed: %w", cmd.Name(), err)
		}
		if result, ok := cmd.(*goRedis.StringSliceCmd); ok {
			// Note: entries are LPUSHed, so redis stores them newest first; reverse them
			data[groups[i]] = s.reverse(result.Val())
		}
	}
	return data, nil
}
// AppendAll stores one entry per group in a single pipeline.
//
// For each group with a non-nil entry, the entry is pushed to the head of its list and the
// list is trimmed to the newest cap entries. Every group's key, including groups with a nil
// entry, gets its expiry refreshed to ttl. A pipeline failure, or a failure of any individual
// command, is returned; command errors are wrapped with %w.
func (s *redisDataStorage) AppendAll(ctx context.Context, data map[DataGroup]interface{}, cap int64) error {
	pipeliner := s.client.Pipeline()
	defer func() { _ = pipeliner.Close() }()
	for group, entry := range data {
		key := s.groupKey(group)
		if entry != nil {
			pipeliner.LPush(ctx, key, entry)
			pipeliner.LTrim(ctx, key, 0, cap-1)
		}
		pipeliner.Expire(ctx, key, ttl)
	}
	cmds, e := pipeliner.Exec(ctx)
	if e != nil {
		return e
	}
	for _, cmd := range cmds {
		if err := cmd.Err(); err != nil {
			return fmt.Errorf("%s failed: %w", cmd.Name(), err)
		}
	}
	return nil
}
// Append stores a single entry for group in one pipeline.
//
// A non-nil entry is pushed to the head of the group's list and the list is trimmed to the
// newest cap entries. In all cases the key's expiry is refreshed to ttl. The individual command
// results are not inspected; only the error returned by Pipelined is propagated.
func (s *redisDataStorage) Append(ctx context.Context, group DataGroup, entry interface{}, cap int64) error {
	key := s.groupKey(group)
	queue := func(p goRedis.Pipeliner) error {
		if entry != nil {
			p.LPush(ctx, key, entry)
			p.LTrim(ctx, key, 0, cap-1)
		}
		p.Expire(ctx, key, ttl)
		return nil
	}
	_, err := s.client.Pipelined(ctx, queue)
	return err
}
// groupKey builds the Redis key of a group as "<prefixData>:<identifier>:<group>".
func (s *redisDataStorage) groupKey(group DataGroup) string {
	const layout = `%s:%s:%s`
	return fmt.Sprintf(layout, prefixData, s.identifier, group)
}
// reverse reverses data in place and returns the same slice.
func (s *redisDataStorage) reverse(data []string) []string {
	for lo, hi := 0, len(data)-1; lo < hi; lo, hi = lo+1, hi-1 {
		data[lo], data[hi] = data[hi], data[lo]
	}
	return data
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package profiler
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/web"
"go.uber.org/fx"
)
// Route settings shared by all pprof endpoints of PProfController.
const (
// RouteGroup is the web route group the pprof mappings are registered under.
RouteGroup = "debug"
// PathPrefixPProf is the path segment, inside RouteGroup, that prefixes every pprof endpoint.
PathPrefixPProf = "pprof"
)
// Module is the profiler's bootstrap module. It runs at bootstrap.DebugPrecedence and invokes
// initialize, which registers the pprof endpoints when a web registrar is available.
var Module = &bootstrap.Module{
Precedence: bootstrap.DebugPrecedence,
Options: []fx.Option{
fx.Invoke(initialize),
},
}
// Use registers Module with bootstrap, allowing a service to include this module from main().
func Use() {
bootstrap.Register(Module)
}
// initDI holds the fx dependencies of initialize.
// WebRegistrar is optional: without it no pprof endpoint is registered.
// Lifecycle is currently not used by initialize.
type initDI struct {
fx.In
Lifecycle fx.Lifecycle
WebRegistrar *web.Registrar `optional:"true"`
}
// initialize registers a PProfController with the web registrar. It is a no-op when no
// registrar was provided.
func initialize(di initDI) {
	if reg := di.WebRegistrar; reg != nil {
		reg.MustRegister(&PProfController{})
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package profiler
import (
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/gin-gonic/gin"
"net/http/pprof"
)
// PProfRequest binds the ":profile" URI parameter of the generic pprof route.
type PProfRequest struct {
Profile string `uri:"profile"`
}
// PProfController exposes net/http/pprof handlers as web mappings (see Mappings).
type PProfController struct{}
// Mappings returns the pprof endpoints, all in RouteGroup, prefixed by PathPrefixPProf and
// accepting any HTTP method:
//   - "<prefix>"          -> pprof.Index
//   - "<prefix>/cmdline"  -> pprof.Cmdline
//   - "<prefix>/profile"  -> pprof.Profile (CPU profile)
//   - "<prefix>/symbol"   -> pprof.Symbol
//   - "<prefix>/trace"    -> pprof.Trace
//   - "<prefix>/:profile" -> c.Profile (named runtime profiles such as heap or goroutine)
func (c *PProfController) Mappings() []web.Mapping {
	path := func(suffix string) string { return PathPrefixPProf + suffix }
	return []web.Mapping{
		web.NewSimpleGinMapping("pprof_gin", RouteGroup, path("/:profile"), web.MethodAny, nil, c.Profile),
		web.NewSimpleMapping("pprof_index", RouteGroup, path(""), web.MethodAny, nil, pprof.Index),
		web.NewSimpleMapping("pprof_cli", RouteGroup, path("/cmdline"), web.MethodAny, nil, pprof.Cmdline),
		web.NewSimpleMapping("pprof_profile", RouteGroup, path("/profile"), web.MethodAny, nil, pprof.Profile),
		web.NewSimpleMapping("pprof_symbol", RouteGroup, path("/symbol"), web.MethodAny, nil, pprof.Symbol),
		web.NewSimpleMapping("pprof_trace", RouteGroup, path("/trace"), web.MethodAny, nil, pprof.Trace),
	}
}
// Profile serves the runtime profile named by the ":profile" URI parameter through
// net/http/pprof. If the parameter cannot be bound, the pprof index page is served instead.
// Unknown profile names are answered by pprof.Handler itself (HTTP 404 "Unknown profile").
func (c *PProfController) Profile(gc *gin.Context) {
	var req PProfRequest
	if e := gc.BindUri(&req); e != nil {
		pprof.Index(gc.Writer, gc.Request)
		return
	}
	// pprof.Handler always returns a non-nil handler, so no nil check is needed (the previous
	// nil-check fallback to the index was unreachable).
	pprof.Handler(req.Profile).ServeHTTP(gc.Writer, gc.Request)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package redis
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/go-redis/redis/v8"
"github.com/pkg/errors"
)
// KeepTTL is an option for Set command to keep key's existing TTL.
// It re-exports go-redis's redis.KeepTTL.
// For example:
//
// rdb.Set(ctx, key, value, redis.KeepTTL)
const KeepTTL = redis.KeepTTL
// ConnOptions is a functional option that adjusts redis.UniversalOptions after they are built
// from RedisProperties (see GetUniversalOptions). A non-nil error aborts option building.
type ConnOptions func(opt *redis.UniversalOptions) error
// GetUniversalOptions converts RedisProperties into go-redis UniversalOptions, then applies each
// ConnOptions in order. It stops at, and returns, the first error reported by an option.
// TLS settings are not copied here; they are applied through a ConnOptions (see withTLS).
func GetUniversalOptions(p *RedisProperties, opts ...ConnOptions) (*redis.UniversalOptions, error) {
	u := new(redis.UniversalOptions)

	// endpoints, DB and credentials
	u.Addrs = p.Addresses
	u.DB = p.DB
	u.Username = p.Username
	u.Password = p.Password

	// retries
	u.MaxRetries = p.MaxRetries
	u.MinRetryBackoff = p.MinRetryBackoff
	u.MaxRetryBackoff = p.MaxRetryBackoff

	// timeouts
	u.DialTimeout = p.DialTimeout
	u.ReadTimeout = p.ReadTimeout
	u.WriteTimeout = p.WriteTimeout

	// connection pool
	u.PoolSize = p.PoolSize
	u.MinIdleConns = p.MinIdleConns
	u.MaxConnAge = p.MaxConnAge
	u.PoolTimeout = p.PoolTimeout
	u.IdleTimeout = p.IdleTimeout
	u.IdleCheckFrequency = p.IdleCheckFrequency

	// cluster clients only
	u.MaxRedirects = p.MaxRedirects
	u.ReadOnly = p.ReadOnly
	u.RouteByLatency = p.RouteByLatency
	u.RouteRandomly = p.RouteRandomly

	// failover (sentinel) clients only
	u.MasterName = p.MasterName
	u.SentinelPassword = p.SentinelPassword

	for i := range opts {
		if err := opts[i](u); err != nil {
			return nil, err
		}
	}
	return u, nil
}
// withDB returns a ConnOptions that selects DB index dbIndex. It never fails.
func withDB(dbIndex int) ConnOptions {
	var selectDB ConnOptions = func(opt *redis.UniversalOptions) error {
		opt.DB = dbIndex
		return nil
	}
	return selectDB
}
// withTLS returns a ConnOptions that obtains a certificate source from certsMgr (configured by p)
// and installs its TLS config on the connection options.
// The option fails if certsMgr is nil, or if the source or its TLS config cannot be obtained;
// on failure opt.TLSConfig is left untouched.
func withTLS(ctx context.Context, certsMgr certs.Manager, p *certs.SourceProperties) ConnOptions {
	return func(opt *redis.UniversalOptions) error {
		if certsMgr == nil {
			return fmt.Errorf("TLS auth is enabled for Redis, but certificate manager is not available")
		}
		src, err := certsMgr.Source(ctx, certs.WithSourceProperties(p))
		if err != nil {
			// errors.Wrap already appends ": <cause>", so the cause must not be formatted into
			// the message again (previously it appeared twice).
			return errors.Wrap(err, "failed to initialize redis connection")
		}
		tlsCfg, err := src.TLSConfig(ctx)
		if err != nil {
			return errors.Wrap(err, "failed to initialize redis connection")
		}
		opt.TLSConfig = tlsCfg
		return nil
	}
}
// Client is the Redis client exposed by this package: it is go-redis's UniversalClient.
// NOTE(review): per go-redis, the concrete client (single node, cluster or failover) depends on
// the options it was created with — confirm against go-redis docs.
type Client interface {
redis.UniversalClient
}
// client is this package's implementation of Client; it embeds the go-redis client created by
// clientFactory.New.
type client struct {
redis.UniversalClient
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package redis
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/go-redis/redis/v8"
"github.com/pkg/errors"
)
// ClientOptions is a functional option customizing the ClientOption passed to ClientFactory.New.
type ClientOptions func(opt *ClientOption)
// ClientOption describes the client requested from ClientFactory.New.
// It is comparable and is used as the cache key of created clients.
type ClientOption struct {
// DbIndex is the Redis DB index to select (clientFactory.New accepts 0-15).
DbIndex int
}
// OptionsAwareHook is a redis.Hook that can derive a client-specific hook from that client's options.
type OptionsAwareHook interface {
redis.Hook
WithClientOption(*redis.UniversalOptions) redis.Hook
}
// ClientFactory creates Redis clients and manages hooks shared by all of them.
type ClientFactory interface {
// New returns a Client for the given options. The default implementation reuses a
// previously created Client when called again with identical options.
New(ctx context.Context, opts ...ClientOptions) (Client, error)
// AddHooks adds hooks to all Clients already created and to any future Client created via this interface.
// If a given hook also implements OptionsAwareHook, its WithClientOption method is used to derive
// a hook instance for each corresponding client.
AddHooks(ctx context.Context, hooks ...redis.Hook)
}
// clientRecord pairs a created Client with the options it was created from, so hooks added
// later can be specialized for it.
type clientRecord struct {
client Client
options *redis.UniversalOptions
}
// clientFactory implements ClientFactory.
// NOTE(review): hooks and clients are accessed without locking; concurrent use is not safe.
type clientFactory struct {
properties RedisProperties
hooks []redis.Hook
clients map[ClientOption]clientRecord
certsManager certs.Manager
}
// FactoryOptions is a functional option customizing the FactoryOption used by NewClientFactory.
type FactoryOptions func(opt *FactoryOption)
// FactoryOption configures NewClientFactory. TLSCertsManager is only needed when TLS is enabled
// in Properties.
type FactoryOption struct {
Properties RedisProperties
TLSCertsManager certs.Manager
}
// NewClientFactory creates a ClientFactory configured by opts, applied in order.
// The factory starts with no hooks and no cached clients.
func NewClientFactory(opts ...FactoryOptions) ClientFactory {
	var cfg FactoryOption
	for i := range opts {
		opts[i](&cfg)
	}
	factory := &clientFactory{
		hooks:   []redis.Hook{},
		clients: map[ClientOption]clientRecord{},
	}
	factory.properties = cfg.Properties
	factory.certsManager = cfg.TLSCertsManager
	return factory
}
// New returns a Client for the DB index selected by opts (0 when not set).
//
// Clients are cached per ClientOption: calling New again with the same options returns the
// client created the first time. A new client is built from the factory's RedisProperties (plus
// TLS settings when enabled), and every hook registered so far is attached to it; hooks that
// implement OptionsAwareHook are first specialized with the client's options.
//
// It fails if the DB index is outside 0-15 or if the connection options cannot be built.
// NOTE(review): the cache is a plain map without locking; concurrent calls are not safe.
func (f *clientFactory) New(ctx context.Context, opts ...ClientOptions) (Client, error) {
	opt := ClientOption{}
	// "fn" rather than "f" so the receiver is not shadowed.
	for _, fn := range opts {
		fn(&opt)
	}

	// Some validations
	if opt.DbIndex < 0 || opt.DbIndex >= 16 {
		return nil, fmt.Errorf("invalid Redis DB index [%d]: must be between 0 and 15", opt.DbIndex)
	}

	if existing, ok := f.clients[opt]; ok {
		return existing.client, nil
	}

	connOpts := []ConnOptions{withDB(opt.DbIndex)}
	if f.properties.TLS.Enabled {
		connOpts = append(connOpts, withTLS(ctx, f.certsManager, &f.properties.TLS.Certs))
	}

	// prepare options
	options, e := GetUniversalOptions(&f.properties, connOpts...)
	if e != nil {
		return nil, errors.Wrap(e, "Invalid redis configuration")
	}

	c := &client{
		UniversalClient: redis.NewUniversalClient(options),
	}

	// apply hooks
	for _, hook := range f.hooks {
		h := hook
		if aware, ok := hook.(OptionsAwareHook); ok {
			h = aware.WithClientOption(options)
		}
		c.AddHook(h)
	}

	// Cache the same *client that is returned, so first and cached calls yield the same value.
	f.clients[opt] = clientRecord{
		client:  c,
		options: options,
	}
	logger.WithContext(ctx).Infof("Redis client created with DB index %d", options.DB)
	return c, nil
}
// AddHooks records hooks for clients created later and attaches them to every client already
// created. Hooks implementing OptionsAwareHook are specialized per client with that client's options.
func (f *clientFactory) AddHooks(ctx context.Context, hooks ...redis.Hook) {
	f.hooks = append(f.hooks, hooks...)

	specialize := func(hook redis.Hook, opts *redis.UniversalOptions) redis.Hook {
		if aware, ok := hook.(OptionsAwareHook); ok {
			return aware.WithClientOption(opts)
		}
		return hook
	}
	for _, hook := range hooks {
		for _, rec := range f.clients {
			rec.client.AddHook(specialize(hook, rec.options))
		}
	}
	logger.WithContext(ctx).Debugf("Added redis hooks: %v", hooks)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package redis
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
"go.uber.org/fx"
)
// regDI holds the fx dependencies of registerHealth.
// HealthRegistrar is optional; without it no health indicator is registered.
type regDI struct {
fx.In
HealthRegistrar health.Registrar `optional:"true"`
RedisClient Client
}
// registerHealth registers a RedisHealthIndicator backed by the injected Client, if a health
// registrar is available.
func registerHealth(di regDI) {
	reg := di.HealthRegistrar
	if reg == nil {
		return
	}
	indicator := &RedisHealthIndicator{client: di.RedisClient}
	reg.MustRegister(indicator)
}
// RedisHealthIndicator reports Redis health by pinging through its client (see Health).
type RedisHealthIndicator struct {
client Client
}
// Name returns the indicator's name, "redis".
func (i *RedisHealthIndicator) Name() string {
return "redis"
}
// Health pings Redis. It reports StatusUp when the ping succeeds; otherwise it logs the error
// and reports StatusDown. options is not used, and no details are attached.
func (i *RedisHealthIndicator) Health(c context.Context, options health.Options) health.Health {
	_, e := i.client.Ping(c).Result()
	if e == nil {
		return health.NewDetailedHealth(health.StatusUp, "redis ping succeeded", nil)
	}
	logger.WithContext(c).Errorf("Health Ping to Redis failed: %s", e)
	return health.NewDetailedHealth(health.StatusDown, "redis ping failed", nil)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package redis
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/opentracing/opentracing-go"
"go.uber.org/fx"
)
// logger is the package logger, named "Redis".
var logger = log.New("Redis")
// Module is the redis bootstrap module. It provides RedisProperties, a ClientFactory and the
// default Client, and registers the Redis health indicator.
var Module = &bootstrap.Module{
Precedence: bootstrap.RedisPrecedence,
Options: []fx.Option{
fx.Provide(BindRedisProperties),
fx.Provide(provideClientFactory),
fx.Provide(provideDefaultClient),
fx.Invoke(registerHealth),
},
}
// Use registers Module with bootstrap, allowing a service to include this module from main().
func Use() {
bootstrap.Register(Module)
}
// factoryDI holds the fx dependencies of provideClientFactory.
// CertManager is only required when TLS is enabled; Tracer, when present, enables the tracing hook.
type factoryDI struct {
fx.In
AppCtx *bootstrap.ApplicationContext
Props RedisProperties
CertManager certs.Manager `optional:"true"`
Tracer opentracing.Tracer `optional:"true"`
}
// provideClientFactory builds the ClientFactory from the bound properties and optional certificate
// manager. When a tracer is available, a tracing hook is added to every client the factory creates.
func provideClientFactory(di factoryDI) ClientFactory {
	configure := func(opt *FactoryOption) {
		opt.Properties = di.Props
		opt.TLSCertsManager = di.CertManager
	}
	factory := NewClientFactory(configure)
	if di.Tracer == nil {
		return factory
	}
	factory.AddHooks(di.AppCtx, NewRedisTrackingHook(di.Tracer))
	return factory
}
// clientDI holds the fx dependencies of provideDefaultClient.
type clientDI struct {
fx.In
AppCtx *bootstrap.ApplicationContext
Factory ClientFactory
Properties RedisProperties
}
// provideDefaultClient creates the default Client on the DB index configured in RedisProperties.
// It panics if the client cannot be created.
func provideDefaultClient(di clientDI) Client {
	dbIndex := di.Properties.DB
	c, err := di.Factory.New(di.AppCtx, func(opt *ClientOption) { opt.DbIndex = dbIndex })
	if err != nil {
		panic(err)
	}
	return c
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package redis
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/certs"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"time"
)
const (
// ConfigRootRedisConnection is the configuration prefix RedisProperties are bound from.
ConfigRootRedisConnection = "redis"
// DefaultDbIndex is DB index 0. NOTE(review): not referenced in this chunk.
DefaultDbIndex = 0
)
// RedisProperties is bound from the "redis" configuration root (see BindRedisProperties) and
// converted into go-redis options by GetUniversalOptions.
type RedisProperties struct {
// Either a single address or a seed list of host:port addresses
// of cluster/sentinel nodes.
Addresses utils.CommaSeparatedSlice `json:"addrs"`
// Database to be selected after connecting to the server.
// Only single-node and failover clients.
DB int `json:"db"`
// Common options.
Username string `json:"username"`
Password string `json:"password"`
MaxRetries int `json:"max-retries"`
MinRetryBackoff time.Duration `json:"min-retry-backoff"`
MaxRetryBackoff time.Duration `json:"max-retry-backoff"`
DialTimeout time.Duration `json:"dial-timeout"`
ReadTimeout time.Duration `json:"read-timeout"`
WriteTimeout time.Duration `json:"write-timeout"`
PoolSize int `json:"pool-size"`
MinIdleConns int `json:"min-idle-conns"`
MaxConnAge time.Duration `json:"max-conn-age"`
PoolTimeout time.Duration `json:"pool-timeout"`
IdleTimeout time.Duration `json:"idle-timeout"`
IdleCheckFrequency time.Duration `json:"idle-check-frequency"`
// TLS Properties for Redis
TLS TLSProperties `json:"tls"`
// Only cluster clients.
MaxRedirects int `json:"max-redirects"`
ReadOnly bool `json:"read-only"`
RouteByLatency bool `json:"route-by-latency"`
RouteRandomly bool `json:"route-randomly"`
// The sentinel master name.
// Only failover clients.
MasterName string `json:"master-name"`
SentinelPassword string `json:"sentinel-password"`
}
// TLSProperties enables TLS for Redis connections; Certs configures the certificate source
// requested from the certificate manager.
type TLSProperties struct {
Enabled bool `json:"enabled"`
Certs certs.SourceProperties `json:"certs"`
}
// BindRedisProperties binds RedisProperties from the "redis" configuration root of the
// application context. It panics if binding fails.
func BindRedisProperties(ctx *bootstrap.ApplicationContext) RedisProperties {
	var props RedisProperties
	err := ctx.Config().Bind(&props, ConfigRootRedisConnection)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind redis.RedisProperties"))
	}
	return props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package redis
import (
"context"
"github.com/cisco-open/go-lanai/pkg/tracing"
goredis "github.com/go-redis/redis/v8"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
// tracingOpName is the operation-name prefix of spans created by redisTracingHook.
const tracingOpName = "redis"
// redisTracingHook implements goredis.Hook and OptionsAwareHook.
// db is the client's DB index; a negative value means unknown, in which case no DB tag is added.
type redisTracingHook struct {
tracer opentracing.Tracer
db int
}
// NewRedisTrackingHook creates a tracing hook that is not yet bound to a DB index (db = -1), so
// its spans carry no DB tag until the hook is specialized through WithClientOption.
func NewRedisTrackingHook(tracer opentracing.Tracer) *redisTracingHook {
	const unboundDB = -1
	return newRedisTrackingHook(tracer, unboundDB)
}

// newRedisTrackingHook creates a tracing hook bound to DB index db (negative means unknown).
func newRedisTrackingHook(tracer opentracing.Tracer, db int) *redisTracingHook {
	hook := &redisTracingHook{db: db}
	hook.tracer = tracer
	return hook
}
// WithClientOption implements OptionsAwareHook: it returns a new hook with the same tracer,
// bound to the client's DB index (opts.DB) so spans can be tagged with it.
func (h redisTracingHook) WithClientOption(opts *goredis.UniversalOptions) goredis.Hook {
return newRedisTrackingHook(h.tracer, opts.DB)
}
// BeforeProcess implements goredis.Hook. It starts a client-kind span named "redis <cmd>",
// tagged with the command name and, when known, the DB index. The span is created through
// DescendantOrNoSpan (presumably only when ctx already carries a span — confirm), and the
// resulting context is returned. It never fails.
func (h redisTracingHook) BeforeProcess(ctx context.Context, cmd goredis.Cmder) (context.Context, error) {
	cmdName := cmd.Name()
	spanOpts := []tracing.SpanOption{
		tracing.SpanKind(ext.SpanKindRPCClientEnum),
		tracing.SpanTag("cmd", cmdName),
	}
	if h.db >= 0 {
		spanOpts = append(spanOpts, tracing.SpanTag("db", h.db))
	}
	op := tracing.WithTracer(h.tracer).
		WithOpName(tracingOpName + " " + cmdName).
		WithOptions(spanOpts...)
	return op.DescendantOrNoSpan(ctx), nil
}
// AfterProcess implements goredis.Hook. It finishes the span carried by ctx (presumably the one
// started in BeforeProcess), tagging it with "err" when the command failed. It never fails.
func (h redisTracingHook) AfterProcess(ctx context.Context, cmd goredis.Cmder) error {
	op := tracing.WithTracer(h.tracer)
	if cmdErr := cmd.Err(); cmdErr != nil {
		op.WithOptions(tracing.SpanTag("err", cmdErr))
	}
	op.Finish(ctx)
	return nil
}
// BeforeProcessPipeline implements goredis.Hook. It starts a client-kind span named "redis-batch",
// tagged with the names of all queued commands and, when known, the DB index. The span is created
// through DescendantOrNoSpan and the resulting context is returned. It never fails.
func (h redisTracingHook) BeforeProcessPipeline(ctx context.Context, cmds []goredis.Cmder) (context.Context, error) {
	name := tracingOpName + "-batch"
	cmdNames := make([]string, len(cmds))
	for i, v := range cmds {
		cmdNames[i] = v.Name()
	}
	opts := []tracing.SpanOption{
		tracing.SpanKind(ext.SpanKindRPCClientEnum),
		tracing.SpanTag("cmd", cmdNames),
	}
	if h.db >= 0 {
		// Same "db" tag key as BeforeProcess (was "data"), so single and batched spans agree.
		opts = append(opts, tracing.SpanTag("db", h.db))
	}
	return tracing.WithTracer(h.tracer).
		WithOpName(name).
		WithOptions(opts...).
		DescendantOrNoSpan(ctx), nil
}
// AfterProcessPipeline implements goredis.Hook. It finishes the span carried by ctx (presumably
// the one started in BeforeProcessPipeline) and, if any command failed, tags it with "err" as a
// map of command name to error. Commands sharing a name keep only the last error seen.
func (h redisTracingHook) AfterProcessPipeline(ctx context.Context, cmds []goredis.Cmder) error {
	op := tracing.WithTracer(h.tracer)
	failures := make(map[string]error)
	for _, c := range cmds {
		if cmdErr := c.Err(); cmdErr != nil {
			failures[c.Name()] = cmdErr
		}
	}
	if len(failures) > 0 {
		op.WithOptions(tracing.SpanTag("err", failures))
	}
	op.Finish(ctx)
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package scheduler
import (
"github.com/robfig/cron/v3"
)
// cronOptions configures the cron parser for 5 required fields (second, minute, hour, day of
// month, month) plus an optional day-of-week field. cron.Descriptor is not included, so
// descriptors such as "@daily" are not accepted.
var cronOptions = cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.DowOptional
// Cron schedules a task using CRON expression
// Supported CRON expression is "<second> <minutes> <hours> <day of month> <month> [day of week]",
// where "day of week" is optional
// Note 1: do not support 'L'
// Note 2: any options affecting start time and repeat rate (StartAt, AtRate, etc.) would take no effect,
// because the cron schedule is applied last and switches the task to dynamic triggering.
// Default task hooks are applied before hooks given in opts.
func Cron(expr string, taskFunc TaskFunc, opts ...TaskOptions) (TaskCanceller, error) {
	all := make([]TaskOptions, 0, len(opts)+2)
	all = append(all, TaskHooks(defaultTaskHooks...))
	all = append(all, opts...)
	all = append(all, withCronExpression(expr))
	return newTask(taskFunc, all...)
}
// withCronExpression returns an option that parses expr when applied and switches the task to
// dynamic triggering driven by the parsed schedule. A parse error is returned from the option.
func withCronExpression(expr string) TaskOptions {
	return func(opt *TaskOption) error {
		next, err := cronNextFunc(expr)
		if err == nil {
			return dynamicNext(next)(opt)
		}
		return err
	}
}
// cronNextFunc parses expr with cronOptions and returns the schedule's Next method, which maps a
// time to the schedule's next activation after it.
func cronNextFunc(expr string) (nextFunc, error) {
	parser := cron.NewParser(cronOptions)
	sched, err := parser.Parse(expr)
	if err != nil {
		return nil, err
	}
	return sched.Next, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package scheduler
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/opentracing/opentracing-go"
"time"
)
// logger is the package logger, named "Scheduler".
var logger = log.New("Scheduler")
// defaultTaskHooks are prepended to the hooks of every task scheduled via Repeat, RunOnce or Cron.
// They are captured when a task is scheduled; see AddDefaultHook.
var defaultTaskHooks []TaskHook
/**************************
Scheduling Functions
**************************/
// Repeat schedules a task that repeats according to opts (e.g. AtRate or WithDelay, which must
// give a positive interval). Default task hooks are applied before hooks given in opts.
func Repeat(taskFunc TaskFunc, opts ...TaskOptions) (TaskCanceller, error) {
	withDefaults := append([]TaskOptions{TaskHooks(defaultTaskHooks...)}, opts...)
	return newTask(taskFunc, withDefaults...)
}
// RunOnce schedules a task that runs only once, at the time given by StartAt/StartAfter
// (immediately if neither is given).
// Note: any options affecting repeat rate (AtRate, WithDelay, etc.) would take no effect.
// Default task hooks are applied before hooks given in opts.
func RunOnce(taskFunc TaskFunc, opts ...TaskOptions) (TaskCanceller, error) {
	all := make([]TaskOptions, 0, len(opts)+2)
	all = append(all, TaskHooks(defaultTaskHooks...))
	all = append(all, opts...)
	all = append(all, runOnceOption())
	return newTask(taskFunc, all...)
}
// AddDefaultHook appends hooks to the default task hooks and stably re-sorts them with
// order.OrderedFirstCompare. Only tasks scheduled afterwards receive them.
// NOTE(review): mutates package-level state without locking.
func AddDefaultHook(hooks ...TaskHook) {
defaultTaskHooks = append(defaultTaskHooks, hooks...)
order.SortStable(defaultTaskHooks, order.OrderedFirstCompare)
}
// EnableTracing adds a default hook that uses the given opentracing.Tracer to start and finish
// a span around each task execution. A nil tracer is ignored.
func EnableTracing(tracer opentracing.Tracer) {
	if tracer == nil {
		return
	}
	AddDefaultHook(newTracingTaskHook(tracer))
}
/**************************
Options
**************************/
// Name option gives the task a name; the task ID then becomes "<name>-<uuid>".
func Name(name string) TaskOptions {
	setName := func(opt *TaskOption) error {
		opt.name = name
		return nil
	}
	return setName
}
// TaskHooks option appends hooks to the task's hooks, after any hooks added earlier.
func TaskHooks(hooks ...TaskHook) TaskOptions {
	var addHooks TaskOptions = func(o *TaskOption) error {
		o.hooks = append(o.hooks, hooks...)
		return nil
	}
	return addHooks
}
// StartAt option sets the task's initial trigger time, which should be in the future.
// Exclusive with StartAfter (not enforced: whichever is applied last wins).
func StartAt(startTime time.Time) TaskOptions {
	var setStart TaskOptions = func(o *TaskOption) error {
		o.initialTime = startTime
		return nil
	}
	return setStart
}
// StartAfter option sets the task's initial trigger to delay after the moment the option is
// applied (i.e. when the task is created). Negative delays are rejected.
// Exclusive with StartAt (not enforced: whichever is applied last wins).
func StartAfter(delay time.Duration) TaskOptions {
	return func(opt *TaskOption) error {
		if delay >= 0 {
			return StartAt(time.Now().Add(delay))(opt)
		}
		return fmt.Errorf("StartAfter doesn't support negative value")
	}
}
// AtRate option selects "Fixed Interval" mode: the task is triggered every repeatInterval, and
// long-running executions may overlap each other.
// Exclusive with WithDelay (not enforced: whichever is applied last wins).
func AtRate(repeatInterval time.Duration) TaskOptions {
	return func(o *TaskOption) error {
		o.mode, o.interval = ModeFixedRate, repeatInterval
		return nil
	}
}
// WithDelay option selects "Fixed Delay" mode: the task is triggered repeatDelay after the
// previous execution finished, so executions never overlap.
// Exclusive with AtRate (not enforced: whichever is applied last wins).
func WithDelay(repeatDelay time.Duration) TaskOptions {
	return func(o *TaskOption) error {
		o.mode, o.interval = ModeFixedDelay, repeatDelay
		return nil
	}
}
// CancelOnError option cancels the scheduled task automatically as soon as any execution
// returns a non-nil error (or panics).
func CancelOnError() TaskOptions {
	var enable TaskOptions = func(o *TaskOption) error {
		o.cancelOnError = true
		return nil
	}
	return enable
}
/**************************
Helpers
**************************/
// runOnceOption switches the task to ModeRunOnce.
func runOnceOption() TaskOptions {
	var once TaskOptions = func(o *TaskOption) error {
		o.mode = ModeRunOnce
		return nil
	}
	return once
}
// dynamicNext switches the task to ModeDynamic, where each trigger time is computed by nextFn.
func dynamicNext(nextFn nextFunc) TaskOptions {
	return func(o *TaskOption) error {
		o.mode, o.nextFunc = ModeDynamic, nextFn
		return nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package scheduler
import (
"context"
"fmt"
"github.com/google/uuid"
"sync"
"time"
)
// task executes a TaskFunc according to its TaskOption, and implements TaskCanceller.
type task struct {
// mtx guards err and serializes Cancel, handleError and the final report in loop.
mtx sync.Mutex
// id is a random UUID, prefixed with "<name>-" when a name is given.
id string
task TaskFunc
option TaskOption
// cancel cancels the context of the running loop; set by start.
cancel context.CancelFunc
// done is buffered (capacity 1); loop sends the final err on it, then closes it.
done chan error
// err is context.Canceled after Cancel, or the last error recorded by handleError.
err error
}
// newTask builds a task from taskFunc and opts (applied in order), validates it and starts its
// loop in the background on a fresh root context.
// It fails if taskFunc is nil, if any option fails, or if a repeating mode (anything other than
// ModeRunOnce and ModeDynamic) has a non-positive interval.
func newTask(taskFunc TaskFunc, opts ...TaskOptions) (TaskCanceller, error) {
	if taskFunc == nil {
		return nil, fmt.Errorf("task function cannot be nil")
	}
	uid := uuid.New().String()
	t := &task{
		id:   uid,
		task: taskFunc,
		done: make(chan error, 1),
	}
	for _, apply := range opts {
		if err := apply(&t.option); err != nil {
			return nil, err
		}
	}
	if t.option.name != "" {
		t.id = fmt.Sprintf("%s-%s", t.option.name, uid)
	}
	repeating := t.option.mode != ModeRunOnce && t.option.mode != ModeDynamic
	if repeating && t.option.interval <= 0 {
		return nil, fmt.Errorf("repeated task should have positive repeat interval")
	}
	t.start(context.Background())
	return t, nil
}
// Cancel implements TaskCanceller. Under t.mtx it records context.Canceled as the task's error
// and cancels the loop's context, which stops any further triggering.
func (t *task) Cancel() {
t.mtx.Lock()
defer t.mtx.Unlock()
t.err = context.Canceled
t.cancel()
}
// Cancelled implements TaskCanceller. The returned channel receives the task's error at the time
// its loop exits (nil if none was recorded), and is then closed.
func (t *task) Cancelled() <-chan error {
return t.done
}
// start derives a cancellable context from ctx, keeps its cancel function for Cancel and
// handleError, and runs the main loop in a new goroutine.
func (t *task) start(ctx context.Context) {
	loopCtx, stop := context.WithCancel(ctx)
	t.cancel = stop
	go t.loop(loopCtx)
}
// loop is the main loop for the task, run in its own goroutine by start.
//
// It waits for the first trigger time, runs the task once, then hands off to the
// mode-specific repeat loop (ModeRunOnce has none). When it returns (ctx cancelled,
// or the single run of a ModeRunOnce task finished), the deferred function sends
// t.err on the buffered done channel and closes it, while holding t.mtx.
func (t *task) loop(ctx context.Context) {
defer func() {
t.mtx.Lock()
defer t.mtx.Unlock()
t.done <- t.err
close(t.done)
}()
// first, figure out first fire time if set
// - ModeDynamic: the next time nextFunc yields after now (initialTime is ignored).
// - initialTime set: wait until then. If it already passed, a ModeFixedRate task is aligned
//   to the next point initialTime + k*interval that is not in the past (delay in [0, interval));
//   any other mode fires immediately.
// - otherwise: delay stays 0 and the first run is immediate.
// NOTE(review): a zero time from nextFunc yields a large negative delay, so the first run
// fires immediately — confirm this is intended.
var delay time.Duration
switch {
case t.option.mode == ModeDynamic:
delay = time.Until(t.option.nextFunc(time.Now()))
case !t.option.initialTime.IsZero():
delay = time.Until(t.option.initialTime)
if delay < 0 {
if t.option.mode == ModeFixedRate {
// adjust using interval (first positive trigger time)
delay = (t.option.interval + (delay % t.option.interval)) % t.option.interval
} else {
delay = 0
}
}
}
select {
case <-time.After(delay):
// wait for the first run to finish, except in modes whose runs may overlap (fixed rate, dynamic)
t.execTask(ctx, t.option.mode != ModeFixedRate && t.option.mode != ModeDynamic)
case <-ctx.Done():
return
}
// repeat if applicable
switch t.option.mode {
case ModeFixedRate:
t.fixedIntervalLoop(ctx)
case ModeFixedDelay:
t.fixedDelayLoop(ctx)
case ModeDynamic:
t.dynamicTriggerLoop(ctx)
case ModeRunOnce:
}
}
// fixedIntervalLoop triggers the task every option.interval without waiting for previous
// executions to finish, until ctx is cancelled.
func (t *task) fixedIntervalLoop(ctx context.Context) {
	tick := time.NewTicker(t.option.interval)
	defer tick.Stop()
	done := ctx.Done()
	for {
		select {
		case <-done:
			return
		case <-tick.C:
			t.execTask(ctx, false)
		}
	}
}
// fixedDelayLoop runs the task, waits for it to finish, then waits option.interval before the
// next run, until ctx is cancelled.
func (t *task) fixedDelayLoop(ctx context.Context) {
	pause := time.NewTimer(t.option.interval)
	defer pause.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-pause.C:
			t.execTask(ctx, true)
			pause.Reset(t.option.interval)
		}
	}
}
// dynamicTriggerLoop triggers the task at each time produced by option.nextFunc, computed from
// the previous trigger time, without waiting for executions to finish, until ctx is cancelled.
//
// A zero time from nextFunc (e.g. cron's Schedule.Next when no matching time exists) ends the
// loop. Without that check, time.Until(zero) is hugely negative, the timer would fire
// immediately on every iteration, and the loop would spin running the task nonstop.
func (t *task) dynamicTriggerLoop(ctx context.Context) {
	next := t.option.nextFunc(time.Now())
	if next.IsZero() {
		return
	}
	timer := time.NewTimer(time.Until(next))
	defer timer.Stop()
	for {
		select {
		case now := <-timer.C:
			t.execTask(ctx, false)
			next = t.option.nextFunc(now)
			if next.IsZero() {
				return
			}
			timer.Reset(time.Until(next))
		case <-ctx.Done():
			return
		}
	}
}
// execTask runs one execution of the task in a new goroutine.
//
// BeforeTrigger hooks run first, in order, each receiving the context returned by the previous
// one; then the task function runs. Afterwards every AfterTrigger hook is called with the latest
// context and the resulting error, and a non-nil error goes to handleError. This also happens if a
// hook or the task panicked: the panic is recovered and turned into an error.
// When wait is true, execTask blocks until that goroutine has finished.
// NOTE(review): a panic inside an AfterTrigger hook is not recovered.
func (t *task) execTask(ctx context.Context, wait bool) {
	errCh := make(chan error, 1)
	go func() {
		execCtx := ctx
		var err error
		defer func() {
			// try recover
			if e := recover(); e != nil {
				err = fmt.Errorf("%v", e)
			}
			// post-hook
			for _, hook := range t.option.hooks {
				hook.AfterTrigger(execCtx, t.id, err)
			}
			// handle error
			if err != nil {
				t.handleError(execCtx, err)
			}
			// notify and cleanup
			errCh <- err
			close(errCh)
		}()
		// pre-hook
		for _, hook := range t.option.hooks {
			execCtx = hook.BeforeTrigger(execCtx, t.id)
		}
		// run task
		err = t.task(execCtx)
	}()
	if !wait {
		return
	}
	// plain receive instead of a single-case select (staticcheck S1000)
	<-errCh
}
// handleError records err as the task's error (under t.mtx). With CancelOnError the task is
// cancelled and the event is logged at Info level; otherwise it is only logged at Debug level.
func (t *task) handleError(ctx context.Context, err error) {
	t.mtx.Lock()
	defer t.mtx.Unlock()
	t.err = err
	lg := logger.WithContext(ctx)
	if !t.option.cancelOnError {
		lg.Debugf("Task [%s] returned with error: %v", t.id, err)
		return
	}
	lg.Infof("Task [%s] cancelled due to error: %v", t.id, err)
	t.cancel()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package scheduler
import (
"context"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
// tracingOpName is the operation name of spans created by tracingTaskHook.
const tracingOpName = "scheduler"
// tracingTaskHook is a TaskHook (see EnableTracing) that wraps each task execution in a span.
type tracingTaskHook struct {
tracer opentracing.Tracer
}
// newTracingTaskHook creates a tracingTaskHook using tracer.
func newTracingTaskHook(tracer opentracing.Tracer) *tracingTaskHook {
return &tracingTaskHook{
tracer: tracer,
}
}
// BeforeTrigger starts a "scheduler" span tagged with the task id, through NewSpanOrFollows
// (presumably a new span, or one following a span already in ctx — confirm), and returns the
// resulting context.
func (h *tracingTaskHook) BeforeTrigger(ctx context.Context, id string) context.Context {
	op := tracing.WithTracer(h.tracer).
		WithOpName(tracingOpName).
		WithOptions(
			tracing.SpanKind(ext.SpanKindRPCClientEnum),
			tracing.SpanTag("task", id),
		)
	return op.NewSpanOrFollows(ctx)
}
// AfterTrigger finishes the span carried by ctx (presumably the one started in BeforeTrigger),
// tagging it with "err" when the execution failed. The task id is not used.
func (h *tracingTaskHook) AfterTrigger(ctx context.Context, _ string, err error) {
	op := tracing.WithTracer(h.tracer)
	if err != nil {
		op.WithOptions(tracing.SpanTag("err", err))
	}
	// no trailing bare return: it was redundant (staticcheck S1023)
	op.Finish(ctx)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package access
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/web"
"net/http"
)
// DecisionMakerFunc determines whether the current user can access the given http.Request:
//   - if the given request is not handled by this function, it returns false, nil
//   - if the given request is handled and the access is granted, it returns true, nil
//   - otherwise, it returns true and a security.ErrorTypeCodeAccessControl error
//
// Note that the implementations built by MakeDecisionMakerFunc and WrapDecisionMakerFunc also return
// false together with a non-nil error when the request matcher itself fails.
type DecisionMakerFunc func(context.Context, *http.Request) (handled bool, decision error)
// AcrMatcher is short for Access Control Request Matcher. It accepts *http.Request or http.Request.
type AcrMatcher web.RequestMatcher
// ControlFunc makes an access control decision based on security.Authentication:
//   - "decision" indicates whether the access is granted
//   - "reason" is optional and is only used when access is denied. If not specified,
//     MakeDecisionMakerFunc falls back to security.NewAccessDeniedError
type ControlFunc func(security.Authentication) (decision bool, reason error)
// MakeDecisionMakerFunc builds a DecisionMakerFunc from a request matcher and a ControlFunc.
// Requests not matched by matcher (or for which matching fails) are reported as not handled, along with
// the matcher's error, if any. Matched requests are decided by cf using the security.Authentication
// from the request context. A denial without a reason is reported as security.NewAccessDeniedError.
func MakeDecisionMakerFunc(matcher AcrMatcher, cf ControlFunc) DecisionMakerFunc {
	return func(ctx context.Context, r *http.Request) (bool, error) {
		if matched, err := matcher.MatchesWithContext(ctx, r); err != nil || !matched {
			// not ours (or undecidable): leave it to other decision makers
			return false, err
		}
		granted, reason := cf(security.Get(ctx))
		if granted {
			return true, nil
		}
		if reason == nil {
			reason = security.NewAccessDeniedError("access denied")
		}
		return true, reason
	}
}
// WrapDecisionMakerFunc restricts dmf to requests matched by matcher. Other requests (or requests for which
// matching fails) are reported as not handled, along with the matcher's error, if any.
func WrapDecisionMakerFunc(matcher AcrMatcher, dmf DecisionMakerFunc) DecisionMakerFunc {
	return func(ctx context.Context, r *http.Request) (bool, error) {
		matched, err := matcher.MatchesWithContext(ctx, r)
		if matched && err == nil {
			return dmf(ctx, r)
		}
		return false, err
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package access
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
)
// ControlCondition is a matcher.ChainableMatcher (meant to be compatible with web.RequestMatcher).
// It is used together with web.RoutedMapping's "Condition" as a convenient way to configure security per route.
// Only matcher.ChainableMatcher's MatchesWithContext(context.Context, interface{}) (bool, error) is used;
// Matches(interface{}) (bool, error) should behave as if the context is empty.
//
// In addition, implementations should return an AccessDeniedError when the condition doesn't match.
// web.Registrar propagates this error along the handler chain until it's handled by the error handling middleware.
type ControlCondition matcher.ChainableMatcher
/**************************
Common Impl.
***************************/
// ConditionWithControlFunc is a common ControlCondition implementation backed by a ControlFunc.
// The condition's result is whatever ControlFunc returns for the security.Authentication found in the context.
type ConditionWithControlFunc struct {
	Description string
	ControlFunc ControlFunc
}

// Matches evaluates the condition with a background (empty) context.
func (m *ConditionWithControlFunc) Matches(i interface{}) (bool, error) {
	return m.MatchesWithContext(context.Background(), i)
}

// MatchesWithContext ignores the matched value and evaluates ControlFunc against security.Get(c).
func (m *ConditionWithControlFunc) MatchesWithContext(c context.Context, _ interface{}) (bool, error) {
	return m.ControlFunc(security.Get(c))
}

// Or returns a matcher matching when this condition or any of the given matchers matches.
func (m *ConditionWithControlFunc) Or(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.Or(m, matchers...)
}

// And returns a matcher matching when this condition and all given matchers match.
func (m *ConditionWithControlFunc) And(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.And(m, matchers...)
}

// String returns Description, or a generic name when Description is empty.
func (m ConditionWithControlFunc) String() string {
	if m.Description == "" {
		return "access.ControlCondition"
	}
	return m.Description
}
/**************************
Constructors
***************************/
// RequirePermissions returns a ControlCondition backed by HasPermissionsWithExpr.
// e.g. RequirePermissions("P1 && P2 && !(P3 || P4)") means security.Permissions contains both P1 and P2
// but contains neither P3 nor P4.
// See HasPermissionsWithExpr for the expression syntax. Like HasPermissionsWithExpr, it panics on an invalid expression.
func RequirePermissions(expr string) ControlCondition {
	desc := fmt.Sprintf("user's permissions match [%s]", expr)
	cf := HasPermissionsWithExpr(expr)
	return &ConditionWithControlFunc{Description: desc, ControlFunc: cf}
}
/**************************
Helpers
***************************/
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package access
import (
	"fmt"
	"github.com/cisco-open/go-lanai/pkg/log"
	"github.com/cisco-open/go-lanai/pkg/security"
	"github.com/cisco-open/go-lanai/pkg/utils/order"
	"github.com/cisco-open/go-lanai/pkg/web/matcher"
	"github.com/cisco-open/go-lanai/pkg/web/middleware"
	"sort"
)
var (
// FeatureId identifies the access control feature. It is used by register to register
// AccessControlConfigurer and returned by AccessControlFeature.Identifier.
FeatureId = security.FeatureId("AC", security.FeatureOrderAccess)
)
// AccessControlConfigurer applies an AccessControlFeature to a security.WebSecurity. It holds no state.
//goland:noinspection GoNameStartsWithPackageName
type AccessControlConfigurer struct{}

// newAccessControlConfigurer creates the (stateless) configurer registered by register.
func newAccessControlConfigurer() *AccessControlConfigurer {
	return new(AccessControlConfigurer)
}
// Apply configures ws according to the given *AccessControlFeature. It:
//  1. validates the feature (an empty ACL defaults to DenyAll, see validate)
//  2. sorts the access control entries with order.OrderedFirstCompare (stable, so equally ordered
//     entries keep their declaration order)
//  3. converts each entry into a DecisionMakerFunc and registers a middleware evaluating them in that order
//
// It returns an error when feature is not an *AccessControlFeature. It also returns an error when an entry
// created by AccessControlFeature.Request was never given a decision (PermitAll, DenyAll, AllowIf,
// CustomDecisionMaker, ...). Such an entry would otherwise invoke a nil ControlFunc at request time.
func (acc *AccessControlConfigurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	f, ok := feature.(*AccessControlFeature)
	if !ok {
		return fmt.Errorf("unable to apply access control: unexpected feature type [%T]", feature)
	}

	// Verify
	if err := acc.validate(f, ws); err != nil {
		return err
	}

	// evaluation order of the entries
	sort.SliceStable(f.acl, func(i, j int) bool {
		return order.OrderedFirstCompare(f.acl[i], f.acl[j])
	})

	// construct decision maker functions
	decisionMakers := make([]DecisionMakerFunc, len(f.acl))
	for i, ac := range f.acl {
		switch {
		case ac.custom != nil:
			// a custom DecisionMakerFunc overrides the ControlFunc
			decisionMakers[i] = WrapDecisionMakerFunc(ac.matcher, ac.custom)
		case ac.control != nil:
			decisionMakers[i] = MakeDecisionMakerFunc(ac.matcher, ac.control)
		default:
			return fmt.Errorf("access control for [%v] has neither ControlFunc nor DecisionMakerFunc configured", ac.matcher)
		}
	}

	// register middlewares
	mw := NewAccessControlMiddleware(decisionMakers...)
	ac := middleware.NewBuilder("access control").
		Order(security.MWOrderAccessControl).
		Use(mw.ACHandlerFunc())
	ws.Add(ac)
	return nil
}
// validate ensures the feature has at least one access control entry. With an empty ACL, the middleware
// built by Apply would handle no request and let everything through. So it adds a DenyAll rule
// for any request instead. It currently never returns an error.
func (acc *AccessControlConfigurer) validate(f *AccessControlFeature, ws security.WebSecurity) error {
	if len(f.acl) != 0 {
		return nil
	}
	logger.Infof("access control is not set, default to DenyAll - [%v]", log.Capped(ws, 80))
	f.Request(matcher.AnyRequest()).DenyAll()
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package access
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
)
// AccessControl is a single access control entry created by AccessControlFeature.Request.
// Requests matched by "matcher" are decided by "custom" when set, otherwise by "control".
// The DSL methods set the decision and return the owning feature for further chaining.
//goland:noinspection GoNameStartsWithPackageName
type AccessControl struct {
	owner   *AccessControlFeature
	order   int
	matcher AcrMatcher
	control ControlFunc
	custom  DecisionMakerFunc
}

// Order implements order.Ordered
func (ac *AccessControl) Order() int {
	return ac.order
}

// WithOrder sets the value returned by Order, which is used to sort entries before evaluation.
func (ac *AccessControl) WithOrder(order int) *AccessControl {
	ac.order = order
	return ac
}

// PermitAll grants access to every matched request.
func (ac *AccessControl) PermitAll() *AccessControlFeature {
	return ac.AllowIf(PermitAll)
}

// DenyAll denies access to every matched request.
func (ac *AccessControl) DenyAll() *AccessControlFeature {
	return ac.AllowIf(DenyAll)
}

// Authenticated grants access to fully authenticated users only.
func (ac *AccessControl) Authenticated() *AccessControlFeature {
	return ac.AllowIf(Authenticated)
}

// HasPermissions grants access to users holding all given permissions.
func (ac *AccessControl) HasPermissions(permissions ...string) *AccessControlFeature {
	return ac.AllowIf(HasPermissions(permissions...))
}

// AllowIf uses cf to decide matched requests.
func (ac *AccessControl) AllowIf(cf ControlFunc) *AccessControlFeature {
	ac.control = cf
	return ac.owner
}

// CustomDecisionMaker overrides the ControlFunc. Order and AcrMatcher are still applied.
func (ac *AccessControl) CustomDecisionMaker(dmf DecisionMakerFunc) *AccessControlFeature {
	ac.custom = dmf
	return ac.owner
}
// AccessControlFeature is the DSL-style access control feature. It holds the configured entries (ACL).
//goland:noinspection GoNameStartsWithPackageName
type AccessControlFeature struct {
	acl []*AccessControl
}

// Identifier implements security.Feature
func (f *AccessControlFeature) Identifier() security.FeatureIdentifier {
	return FeatureId
}

// Request appends a new entry applying to requests matching the given AcrMatcher and returns it.
// The returned entry still needs a decision, e.g. PermitAll or AllowIf.
func (f *AccessControlFeature) Request(matcher AcrMatcher) *AccessControl {
	entry := &AccessControl{owner: f, matcher: matcher}
	f.acl = append(f.acl, entry)
	return entry
}
// Configure enables the access control feature on ws and returns it for further configuration.
// It panics if ws doesn't implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *AccessControlFeature {
	modifier, ok := ws.(security.FeatureModifier)
	if !ok {
		panic(fmt.Errorf("unable to configure access control: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
	}
	return modifier.Enable(New()).(*AccessControlFeature)
}

// New is the standard security.Feature entrypoint, DSL style. Used with security.WebSecurity
func New() *AccessControlFeature {
	return new(AccessControlFeature)
}
/**************************
Common ControlFunc
***************************/
// PermitAll is a ControlFunc that always grants access.
func PermitAll(_ security.Authentication) (bool, error) {
	return true, nil
}

// DenyAll is a ControlFunc that always denies access without a specific reason.
func DenyAll(_ security.Authentication) (bool, error) {
	return false, nil
}

// Authenticated is a ControlFunc granting access only when auth's state is at least security.StateAuthenticated.
// Otherwise, it denies with an insufficient-auth error.
func Authenticated(auth security.Authentication) (bool, error) {
	if auth.State() < security.StateAuthenticated {
		return false, security.NewInsufficientAuthError("not fully authenticated")
	}
	return true, nil
}
// Note: More ControlFunc can be found in permissions.go
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package access
import (
"github.com/gin-gonic/gin"
)
// AccessControlMiddleware consults a list of DecisionMakerFunc for each request, see ACHandlerFunc.
//goland:noinspection GoNameStartsWithPackageName
type AccessControlMiddleware struct {
	decisionMakers []DecisionMakerFunc
}

// NewAccessControlMiddleware creates a middleware consulting the given decision makers in the given order.
func NewAccessControlMiddleware(decisionMakers ...DecisionMakerFunc) *AccessControlMiddleware {
	mw := &AccessControlMiddleware{}
	mw.decisionMakers = decisionMakers
	return mw
}
// ACHandlerFunc returns a gin.HandlerFunc that consults the decision makers in order. It stops at the
// first one that either handles the request or returns an error:
//   - if an error was returned (access denied, or the decision maker failed), the request is aborted with it
//   - otherwise (access granted, or no decision maker handled the request), the next handler is invoked
func (ac *AccessControlMiddleware) ACHandlerFunc() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		var err error
		for _, decisionMaker := range ac.decisionMakers {
			var handled bool
			handled, err = decisionMaker(ctx, ctx.Request)
			// An error from a decision maker that didn't handle the request (e.g. a failing matcher)
			// must not be overwritten by the next decision maker's nil error. Otherwise access control
			// would silently fail open. Stop and report it instead.
			if handled || err != nil {
				break
			}
		}
		if err != nil {
			// access denied
			ac.handleError(ctx, err)
			return
		}
		ctx.Next()
	}
}

// handleError records err on the gin context and aborts the handler chain.
// Rendering the error is left to the error handling middleware.
func (ac *AccessControlMiddleware) handleError(c *gin.Context, err error) {
	// c.Error returns the wrapped *gin.Error, which is not needed here
	_ = c.Error(err)
	c.Abort()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package access
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security"
"go.uber.org/fx"
)
// logger is the logger of the access package (e.g. validate logs the DenyAll fallback with it).
var logger = log.New("SEC.Access")
// Module is the bootstrap module of access control. Its only fx option invokes register.
//goland:noinspection GoNameStartsWithPackageName
var Module = &bootstrap.Module{
Name: "access control",
Precedence: security.MinSecurityPrecedence + 30,
Options: []fx.Option{
fx.Invoke(register),
},
}
// init registers Module with bootstrap during package initialization.
func init() {
bootstrap.Register(Module)
}
// initDI holds the injected dependencies of register.
type initDI struct {
fx.In
// optional: nil when no security.Registrar is available, in which case register does nothing
SecRegistrar security.Registrar `optional:"true"`
}
// register registers AccessControlConfigurer for FeatureId when a security.Registrar is available.
// It panics if the registrar doesn't implement security.FeatureRegistrar.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	registrar := di.SecRegistrar.(security.FeatureRegistrar)
	registrar.RegisterFeature(FeatureId, newAccessControlConfigurer())
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package access
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
"strings"
)
/**************************
Common ControlFunc
***************************/
// HasPermissions returns a ControlFunc that grants access when auth is not anonymous and holds all given
// permissions (via security.HasPermissions). Otherwise, it denies access with:
//   - an insufficient-auth error "not authenticated" when the principal is not known
//   - an insufficient-auth error "not fully authenticated" when not fully authenticated
//   - security.NewAccessDeniedError otherwise
func HasPermissions(permissions ...string) ControlFunc {
	return func(auth security.Authentication) (bool, error) {
		state := auth.State()
		if state > security.StateAnonymous && security.HasPermissions(auth, permissions...) {
			return true, nil
		}
		if state < security.StatePrincipalKnown {
			return false, security.NewInsufficientAuthError("not authenticated")
		}
		if state < security.StateAuthenticated {
			return false, security.NewInsufficientAuthError("not fully authenticated")
		}
		return false, security.NewAccessDeniedError("access denied")
	}
}
/**************************
Permission Expr
***************************/
// Tokens of the permission expression syntax accepted by HasPermissionsWithExpr:
// logical operators "&&", "||", "!" and brackets "(", ")".
const (
opAnd = "&&"
opOr = "||"
opNot = "!"
opOpen = "("
opClose = ")"
)
// HasPermissionsWithExpr takes an expression and returns a ControlFunc that evaluates security.Permissions
// against it. It panics if the expression cannot be parsed.
//
// The expression is composed of 1 or more expression-units combined using logical operators and brackets.
// Supported expression-units are:
//   - !<permission>
//   - <permission> && <permission>
//   - <permission> || <permission>
//
// where <permission> stands for "security.Permissions.Has(<permission>)", which yields a bool result.
// e.g. "P1 && P2 && !(P3 || P4)" means security.Permissions contains both P1 and P2 but contains neither P3 nor P4.
//
// A non-anonymous user holding security.SpecialPermissionAPIAdmin is always granted.
func HasPermissionsWithExpr(expr string) ControlFunc {
	permMatcher, parseErr := parsePermissionExpr(expr)
	if permMatcher == nil {
		compact := strings.ReplaceAll(expr, " ", "")
		panic(fmt.Errorf(`Invalid permission expression "%s": %v`, compact, parseErr))
	}
	return func(auth security.Authentication) (bool, error) {
		state := auth.State()
		if state > security.StateAnonymous {
			// users with the API admin permission short-cut the permission check
			if security.HasPermissions(auth, security.SpecialPermissionAPIAdmin) {
				return true, nil
			}
			if ok, err := permMatcher.Matches(auth.Permissions()); ok && err == nil {
				return true, nil
			}
		}
		if state < security.StatePrincipalKnown {
			return false, security.NewInsufficientAuthError("not authenticated")
		}
		if state < security.StateAuthenticated {
			return false, security.NewInsufficientAuthError("not fully authenticated")
		}
		return false, security.NewAccessDeniedError("access denied")
	}
}
/**************************
Expr Parsing Helpers
***************************/
// opTokens is the set of all operator and bracket tokens recognized by parsePermissionExpr.
var opTokens = utils.NewStringSet(opAnd, opOr, opNot, opOpen, opClose)
// operand is an operator of a permission expression together with its evaluation order:
// operands with higher order are reduced first by processOperand.
// Precedence raises the order of operators nested inside brackets.
type operand struct {
	op    string
	order int
}

// operandFromString returns the operand of an operator token with its base order
// ("||" = 1, "&&" = 2, "!" = 3). Any other string yields the empty operand (op "", order 0).
func operandFromString(str string) operand {
	var order int
	switch str {
	case opOr:
		order = 1
	case opAnd:
		order = 2
	case opNot:
		order = 3
	default:
		return operand{}
	}
	return operand{op: str, order: order}
}

// Precedence returns a copy of o with its order raised by p.
func (o operand) Precedence(p int) operand {
	o.order += p
	return o
}

// String returns the operator token.
func (o operand) String() string {
	return o.op
}
// parsePermissionExpr takes an expression and parses it into a single matcher.Matcher.
// It parses the expr with the help of two FILO stacks ("ops" and "args"):
//   - "ops" holds all operands to be processed ("!", "&&", "||"), each with a precedence/order representing evaluation priority
//   - "args" holds permission matchers to be processed
//   - "args" is eventually reduced into a single matcher that represents the overall expression
//
// Token positions are validated while scanning: "&&" and "||" must follow a permission or ")",
// and "!" (a prefix operator) must not follow a permission or ")". So inputs like "A && || B",
// "&& B" or "A! && B" are rejected instead of being silently reinterpreted.
// See processOperand for more details.
func parsePermissionExpr(expr string) (ret matcher.Matcher, err error) {
	expr = strings.ReplaceAll(expr, " ", "")
	var ops []operand
	var args []matcher.Matcher
	// lastToken is the previously scanned token ("" before the first one)
	lastToken := ""
	var precedence, idx int
	for remaining := expr; remaining != ""; {
		t, r := nextToken(remaining)
		if opTokens.Has(t) {
			switch t {
			case opOpen:
				precedence = precedence + 10
			case opClose:
				if precedence == 0 {
					return nil, fmt.Errorf(`found ")" without matching "(" at idx %d`, idx)
				}
				precedence = precedence - 10
			default:
				// "&&", "||" or "!"
				if posErr := checkOperandPosition(t, lastToken, idx); posErr != nil {
					return nil, posErr
				}
				op := operandFromString(t).Precedence(precedence)
				ops, args = processOperand(ops, args, op)
				if op.op == "" || ops == nil {
					return nil, fmt.Errorf(`unexpected error at idx %d`, idx)
				}
			}
		} else {
			if strings.ContainsAny(t, "&|!()") {
				return nil, fmt.Errorf(`invalid permission value idx %d`, idx)
			}
			args = append(args, NewPermissionMatcher(t))
		}
		idx = idx + len(t)
		lastToken = t
		remaining = r
	}
	if precedence != 0 {
		// we don't have matching number of "(" and ")"
		return nil, fmt.Errorf(`unexpected EOF, found "(" without matching ")"`)
	}
	ops, args = processOperand(ops, args, operand{})
	if len(ops) != 1 || len(args) != 1 {
		return nil, fmt.Errorf(`unexpected EOF, unknown error`)
	}
	ret = args[0]
	return
}

// checkOperandPosition verifies that operand t ("&&", "||" or "!") may appear right after lastToken
// ("" at the beginning of the expression). idx is only used for the error message.
func checkOperandPosition(t, lastToken string, idx int) error {
	// afterValue: the previous token ends a value, i.e. it is a permission or ")"
	afterValue := lastToken != "" && (!opTokens.Has(lastToken) || lastToken == opClose)
	switch {
	case t == opNot && afterValue:
		// "!" is a prefix operator: "A!" or "A! && B" are invalid
		return fmt.Errorf(`found "!" following a permission or ")" at idx %d`, idx)
	case t != opNot && !afterValue:
		// "&&" and "||" need a left-hand value: "A && || B", "(&& B)" or "&& B" are invalid
		return fmt.Errorf(`found "%s" without a preceding permission or ")" at idx %d`, t, idx)
	}
	return nil
}
// processOperand takes the new operand and the existing operand and args stacks, and returns the processed stacks
// using these rules:
// 1. ops and args are stacks (FILO)
// 2. all elements in the ops stack should be in ASC (non-descending) order
// 3.1. if the ops stack is empty, OR the order of its top is not higher than newOp's order,
//      the newOp is pushed into the ops stack
// 3.2 otherwise, the existing ops stack is reduced by popping the top ops sharing the same operand/order
//     and combining the corresponding args at the top of the args stack into a single value (see combine)
// 3.3 repeat 3.1 & 3.2 until condition 2 is satisfied
//
// It returns nil, nil when the stacks are inconsistent: different operands share the same order,
// or there are not enough args for the operands being reduced.
//
// e.g.
// A || !(B || C && ! D) || !E || ! !F
// 1 3 11 12 13 1 3 1 3 3
// ^
// A || !(B || C && ?) || !E || ! !F [? = !D]
// 1 3 11 12 1 3 1 3 3
// ^
// A || !(B || ?) || !E || ! !F [? = C && !D]
// 1 3 11 1 3 1 3 3
// ^
// A || ! ? || !E || ! !F [? = B || C && !D]
// 1 3 1 3 1 3 3
// ^
// Done, move on to next operand
// A || ? || !E || ! !F [? = !(B || C && !D)]
// 1 1 3 1 3 3
// ^
// ...
func processOperand(ops []operand, args []matcher.Matcher, newOp operand) ([]operand, []matcher.Matcher) {
// special case: empty ops, newOp is pushed directly
if len(ops) == 0 {
return append(ops, newOp), args
}
// repeat 3.1, 3.2 until the top of the ops stack has an order not higher than newOp
for {
// terminal condition: if top of the ops stack is not higher order than the target, we are done
if len(ops) == 0 || ops[len(ops) - 1].order <= newOp.order {
break
}
// walk down all ops at the top sharing the same order; "arg" ends at the index of the first arg they consume
last := ops[len(ops) - 1]
arg := len(args) - 1
var opCount int
for i := len(ops) - 1; i >= 0 && ops[i].order == last.order; i-- {
opCount ++
op := ops[i].op
if op != opNot {
// binary operand: N stacked binary operands consume N+1 args, so reach one more arg
arg --
}
if op != last.op || arg < 0 {
// 1. for each precedence, we only have one operand: either ||, && or !.
//    so if we find a different operand with the same precedence, something must be wrong
// 2. not enough args
return nil, nil
}
}
// combine top of args stack into single matcher using the operand at the top of the operand stack
ops, args = combine(ops, args, len(ops) - opCount, arg)
}
// push newOp and return
ops = append(ops, newOp)
return ops, args
}
// combine reduces the top of both stacks into a single arg and returns the shrunk stacks:
//   - operands from index opIdx up to the top of "ops" are popped. processOperand only calls this
//     when they are all the same operand with the same order
//   - args from index aIdx up to the top of "args" are replaced by a single matcher:
//     "&&" / "||" combine them with matcher.And / matcher.Or;
//     "!" negates args[aIdx] when the number of popped "!" is odd (an even count cancels out)
func combine(ops []operand, args []matcher.Matcher, opIdx, aIdx int) ([]operand, []matcher.Matcher) {
op := ops[opIdx]
combined := args[aIdx]
switch op.op {
case opNot:
// "!!X" == "X": only negate when the number of stacked "!" is odd
if (len(ops) - opIdx) % 2 == 1 {
combined = matcher.Not(combined)
}
case opOr:
combined = matcher.Or(args[aIdx], args[aIdx + 1:]...)
case opAnd:
combined = matcher.And(args[aIdx], args[aIdx + 1:]...)
}
ops = ops[:opIdx]
args = append(args[:aIdx], combined)
return ops, args
}
// nextToken returns the first token of str and the rest of str after it. str must not contain spaces.
// A token is an operator/bracket ("&&", "||", "!", "(", ")") when str starts with one,
// otherwise everything before the first operator/bracket. Without any operator, the whole str is the token.
func nextToken(str string) (token string, remaining string) {
	for pos := range str {
		prefix := str[:pos+1]
		found, op := isEndWithOp(prefix)
		if !found {
			continue
		}
		opStart := len(prefix) - len(op)
		if opStart > 0 {
			// characters precede the operator: they form the token, the operator stays in the remaining
			return prefix[:opStart], str[opStart:]
		}
		// str begins with the operator itself
		return op, str[pos+1:]
	}
	return str, ""
}

// isEndWithOp reports whether str ends with an operator/bracket token and returns that token.
// Single-character tokens ("!", "(", ")") are checked before two-character ones ("&&", "||").
func isEndWithOp(str string) (match bool, op string) {
	n := len(str)
	if n == 0 {
		return false, ""
	}
	if last := str[n-1:]; last == opNot || last == opOpen || last == opClose {
		return true, last
	}
	if n >= 2 {
		if lastTwo := str[n-2:]; lastTwo == opAnd || lastTwo == opOr {
			return true, lastTwo
		}
	}
	return false, ""
}
/**************************
Permission Matcher
***************************/
// permissionMatcher implements matcher.ChainableMatcher. It matches a security.Permissions value that
// contains the configured permission; any other input (including other map types) doesn't match.
type permissionMatcher struct {
	permission string
}

// NewPermissionMatcher creates a matcher checking for the given permission.
func NewPermissionMatcher(permission string) *permissionMatcher {
	return &permissionMatcher{permission: permission}
}

// Matches reports whether i is a security.Permissions containing the permission. It never returns an error.
func (m *permissionMatcher) Matches(i interface{}) (bool, error) {
	perms, ok := i.(security.Permissions)
	if !ok {
		return false, nil
	}
	return perms.Has(m.permission), nil
}

// MatchesWithContext ignores the context and delegates to Matches.
func (m *permissionMatcher) MatchesWithContext(_ context.Context, i interface{}) (bool, error) {
	return m.Matches(i)
}

// Or returns a matcher matching when this or any of the given matchers matches.
func (m *permissionMatcher) Or(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.Or(m, matchers...)
}

// And returns a matcher matching when this and all given matchers match.
func (m *permissionMatcher) And(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.And(m, matchers...)
}

// String returns the permission name.
func (m *permissionMatcher) String() string {
	return m.permission
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"context"
"strings"
"time"
)
/******************************
Abstraction - Basics
******************************/
// AccountType categorizes an Account.
type AccountType int

const (
	AccountTypeUnknown AccountType = iota
	AccountTypeDefault
	AccountTypeApp
	AccountTypeFederated
	AccountTypeSystem
)

// String returns the textual name of the account type ("user", "app", "fed" or "system").
// It returns "" for AccountTypeUnknown and any unrecognized value. ParseAccountType accepts these names.
func (t AccountType) String() string {
	switch t {
	case AccountTypeDefault:
		return "user"
	case AccountTypeApp:
		return "app"
	case AccountTypeFederated:
		return "fed"
	case AccountTypeSystem:
		return "system"
	default:
		return ""
	}
}

// ParseAccountType converts value into an AccountType:
//   - an AccountType is returned as-is
//   - a string is matched case-insensitively against "user", "app", "fed" and "system"
//   - anything else, including nil, other types and unrecognized strings, yields AccountTypeUnknown
//
// Note: the previous implementation switched on the type-assertion flag ("switch v, ok := value.(string); ok").
// That made every non-string input (e.g. nil or a number) match the first case and return AccountTypeDefault.
func ParseAccountType(value interface{}) AccountType {
	switch v := value.(type) {
	case AccountType:
		return v
	case string:
		switch strings.ToLower(v) {
		case "user":
			return AccountTypeDefault
		case "app":
			return AccountTypeApp
		case "fed":
			return AccountTypeFederated
		case "system":
			return AccountTypeSystem
		}
	}
	return AccountTypeUnknown
}
// Account is the basic abstraction of an account known to the security framework.
type Account interface {
ID() interface{}
// Type categorizes the account, see AccountType
Type() AccountType
Username() string
Credentials() interface{}
Permissions() []string
Disabled() bool
Locked() bool
UseMFA() bool
// CacheableCopy should return a copy of the Account that is suitable for putting into a cache.
// e.g. the copy should be serializable and shouldn't contain Credentials or any reloadable content
CacheableCopy() Account
}
// AccountFinalizeOption carries the optional parameters of AccountFinalizer.Finalize.
type AccountFinalizeOption struct {
	Tenant *Tenant // Tenant field can be nil
}

// AccountFinalizeOptions is a functional option mutating an AccountFinalizeOption.
type AccountFinalizeOptions func(option *AccountFinalizeOption)

// FinalizeWithTenant returns an option setting AccountFinalizeOption.Tenant to the given tenant (may be nil).
func FinalizeWithTenant(tenant *Tenant) AccountFinalizeOptions {
	return func(opt *AccountFinalizeOption) {
		opt.Tenant = tenant
	}
}
// AccountFinalizer customizes an Account before it is put into the security context.
type AccountFinalizer interface {
// Finalize is a function that allows a service to modify the account before it
// is put into the security context. An example usage of this is to allow for per-tenant
// permissions, where a user can have different permissions depending on which tenant is selected.
//
// Note that the Account.ID and Account.Username should not be changed. If those fields are changed
// an error will be reported.
Finalize(ctx context.Context, account Account, options ...AccountFinalizeOptions) (Account, error)
}
// AccountStore loads and persists accounts and their associated rules.
type AccountStore interface {
// LoadAccountById finds an account by its ID
LoadAccountById(ctx context.Context, id interface{}) (Account, error)
// LoadAccountByUsername finds an account by its Username
LoadAccountByUsername(ctx context.Context, username string) (Account, error)
// LoadLockingRules loads the given account's locking rule. It's recommended to cache the result
LoadLockingRules(ctx context.Context, acct Account) (AccountLockingRule, error)
// LoadPwdAgingRules loads the given account's password policy. It's recommended to cache the result
LoadPwdAgingRules(ctx context.Context, acct Account) (AccountPwdAgingRule, error)
// Save saves the account if necessary
Save(ctx context.Context, acct Account) error
}
// AutoCreateUserDetails is passed to FederatedAccountStore.LoadAccountByExternalId.
// NOTE(review): judging by the method names, it configures automatic creation of accounts for external
// users (email allow-list, attribute mapping, role names) — confirm with implementations.
type AutoCreateUserDetails interface {
IsEnabled() bool
GetEmailWhiteList() []string
GetAttributeMapping() map[string]string
GetElevatedUserRoleNames() []string
GetRegularUserRoleNames() []string
}
// FederatedAccountStore loads accounts identified by an external identity provider.
type FederatedAccountStore interface {
LoadAccountByExternalId(ctx context.Context, externalIdName string, externalIdValue string, externalIdpName string, autoCreateUserDetails AutoCreateUserDetails, rawAssertion interface{}) (Account, error)
}
/*********************************
Abstraction - Auth History
*********************************/
// AccountHistory exposes the login, lockout and password-change history of an account.
type AccountHistory interface {
LastLoginTime() time.Time
LoginFailures() []time.Time
SerialFailedAttempts() int
LockoutTime() time.Time
PwdChangedTime() time.Time
GracefulAuthCount() int
}
/********************************
Abstraction - Multi Tenancy
*********************************/
// AccountTenancy exposes the tenant related information of an account.
type AccountTenancy interface {
DefaultDesignatedTenantId() string
DesignatedTenantIds() []string
TenantId() string
}
/*********************************
Abstraction - Mutator
*********************************/
// AccountUpdater mutates the lock and login state of an account.
// NOTE(review): changes are presumably persisted via AccountStore.Save — confirm with callers.
type AccountUpdater interface {
IncrementGracefulAuthCount()
ResetGracefulAuthCount()
LockAccount()
UnlockAccount()
// RecordFailure records a login failure at failureTime.
// NOTE(review): "limit" presumably caps the number of recorded failures — confirm with implementations.
RecordFailure(failureTime time.Time, limit int)
RecordSuccess(loginTime time.Time)
ResetFailedAttempts()
}
/*********************************
Abstraction - Locking Rules
*********************************/
// AccountLockingRule describes when and for how long an account gets locked after login failures.
type AccountLockingRule interface {
// LockoutPolicyName is the name of the locking rule
LockoutPolicyName() string
// LockoutEnabled indicates whether account locking is enabled
LockoutEnabled() bool
// LockoutDuration specifies how long the account should be locked after consecutive login failures
LockoutDuration() time.Duration
// LockoutFailuresLimit specifies how many consecutive login failures are required to lock the account
LockoutFailuresLimit() int
// LockoutFailuresInterval specifies the maximum time between the first and the last login failures for them to be considered consecutive
LockoutFailuresInterval() time.Duration
}
/*********************************
Abstraction - Aging Rules
*********************************/
// AccountPwdAgingRule describes the password aging (expiry) policy of an account.
type AccountPwdAgingRule interface {
// PwdAgingPolicyName is the name of the password policy
PwdAgingPolicyName() string
// PwdAgingRuleEnforced indicates whether the password policy is enabled
PwdAgingRuleEnforced() bool
// PwdMaxAge specifies how long a password is valid before expiry
PwdMaxAge() time.Duration
// PwdExpiryWarningPeriod specifies how long before password expiry the system should warn the user
PwdExpiryWarningPeriod() time.Duration
// GracefulAuthLimit specifies how many logins are allowed after password expiry
GracefulAuthLimit() int
}
/*********************************
Abstraction - Metadata
*********************************/
// AccountMetadata exposes descriptive, non-security attributes of an account.
type AccountMetadata interface {
RoleNames() []string
FirstName() string
LastName() string
Email() string
LocaleCode() string
CurrencyCode() string
// Value returns an additional metadata value by key
Value(key string) interface{}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"time"
)
// AcctDetails holds the core fields of DefaultAccount. DefaultAccount's security.Account, tenancy and
// history getters read from it.
type AcctDetails struct {
ID string
Type AccountType
Username string
// Credentials is cleared by DefaultAccount.CacheableCopy
Credentials interface{}
Permissions []string
Disabled bool
Locked bool
UseMFA bool
DefaultDesignatedTenantId string
DesignatedTenantIds []string
TenantId string
LastLoginTime time.Time
LoginFailures []time.Time
SerialFailedAttempts int
// LockoutTime is set by DefaultAccount.LockAccount and kept on unlock for record keeping
LockoutTime time.Time
PwdChangedTime time.Time
GracefulAuthCount int
PolicyName string
}
// AcctLockingRule holds locking rule settings of DefaultAccount.
// NOTE(review): presumably backs security.AccountLockingRule; the matching methods are not visible here.
type AcctLockingRule struct {
Name string
Enabled bool
LockoutDuration time.Duration
FailuresLimit int
FailuresInterval time.Duration
}
// AcctPasswordPolicy holds password policy settings of DefaultAccount.
// NOTE(review): presumably backs security.AccountPwdAgingRule; the matching methods are not visible here.
type AcctPasswordPolicy struct {
Name string
Enabled bool
MaxAge time.Duration
ExpiryWarningPeriod time.Duration
GracefulAuthLimit int
}
// AcctMetadata holds descriptive attributes of DefaultAccount. Extra holds additional key/value metadata.
type AcctMetadata struct {
RoleNames []string
FirstName string
LastName string
Email string
LocaleCode string
CurrencyCode string
Extra map[string]interface{}
}
// DefaultAccount is the default security.Account implementation, composed of the embedded structs above.
// Name and Enabled are declared by both AcctLockingRule and AcctPasswordPolicy, so on a DefaultAccount they
// are ambiguous and must be accessed through the embedded struct (e.g. a.AcctLockingRule.Name).
type DefaultAccount struct {
AcctDetails
AcctLockingRule
AcctPasswordPolicy
AcctMetadata
}
// NewUsernamePasswordAccount creates a DefaultAccount holding a copy of the given details.
// Rules, policy and metadata are left at their zero values. details must not be nil.
func NewUsernamePasswordAccount(details *AcctDetails) *DefaultAccount {
	acct := &DefaultAccount{}
	acct.AcctDetails = *details
	return acct
}
/***********************************
implements security.Account
***********************************/
// ID returns AcctDetails.ID
func (a *DefaultAccount) ID() interface{} {
	return a.AcctDetails.ID
}

// Type returns AcctDetails.Type
func (a *DefaultAccount) Type() AccountType {
	return a.AcctDetails.Type
}

// Username returns AcctDetails.Username
func (a *DefaultAccount) Username() string {
	return a.AcctDetails.Username
}

// Credentials returns AcctDetails.Credentials
func (a *DefaultAccount) Credentials() interface{} {
	return a.AcctDetails.Credentials
}

// Permissions returns AcctDetails.Permissions
func (a *DefaultAccount) Permissions() []string {
	return a.AcctDetails.Permissions
}

// Disabled returns AcctDetails.Disabled
func (a *DefaultAccount) Disabled() bool {
	return a.AcctDetails.Disabled
}

// Locked returns AcctDetails.Locked
func (a *DefaultAccount) Locked() bool {
	return a.AcctDetails.Locked
}

// UseMFA returns AcctDetails.UseMFA
func (a *DefaultAccount) UseMFA() bool {
	return a.AcctDetails.UseMFA
}

// CacheableCopy returns a copy of the account suitable for caching:
//   - only AcctDetails and AcctMetadata are kept (the locking rule and password policy are dropped)
//   - Credentials is cleared
//   - slices and the Extra map are copied, so the cached copy and this account don't share mutable
//     backing storage. nil slices/maps stay nil; Extra's values are copied shallowly.
func (a *DefaultAccount) CacheableCopy() Account {
	cp := DefaultAccount{
		AcctDetails:  a.AcctDetails,
		AcctMetadata: a.AcctMetadata,
	}
	// credentials must never be cached
	cp.AcctDetails.Credentials = nil

	// detach reference-typed fields from the live account
	if s := a.AcctDetails.Permissions; s != nil {
		cp.AcctDetails.Permissions = append(make([]string, 0, len(s)), s...)
	}
	if s := a.AcctDetails.DesignatedTenantIds; s != nil {
		cp.AcctDetails.DesignatedTenantIds = append(make([]string, 0, len(s)), s...)
	}
	if s := a.AcctDetails.LoginFailures; s != nil {
		cp.AcctDetails.LoginFailures = append(make([]time.Time, 0, len(s)), s...)
	}
	if s := a.AcctMetadata.RoleNames; s != nil {
		cp.AcctMetadata.RoleNames = append(make([]string, 0, len(s)), s...)
	}
	if m := a.AcctMetadata.Extra; m != nil {
		extra := make(map[string]interface{}, len(m))
		for k, v := range m {
			extra[k] = v
		}
		cp.AcctMetadata.Extra = extra
	}
	return &cp
}
/***********************************
	implements security.AccountTenancy
 ***********************************/

// DefaultDesignatedTenantId returns AcctDetails.DefaultDesignatedTenantId.
func (a *DefaultAccount) DefaultDesignatedTenantId() string {
	return a.AcctDetails.DefaultDesignatedTenantId
}

// DesignatedTenantIds returns AcctDetails.DesignatedTenantIds as stored (not copied).
func (a *DefaultAccount) DesignatedTenantIds() []string {
	return a.AcctDetails.DesignatedTenantIds
}

// TenantId returns AcctDetails.TenantId.
func (a *DefaultAccount) TenantId() string {
	return a.AcctDetails.TenantId
}
/***********************************
	implements security.AccountHistory
 ***********************************/

// LastLoginTime returns AcctDetails.LastLoginTime.
func (a *DefaultAccount) LastLoginTime() time.Time {
	return a.AcctDetails.LastLoginTime
}

// LoginFailures returns the stored failure history (not copied).
func (a *DefaultAccount) LoginFailures() []time.Time {
	return a.AcctDetails.LoginFailures
}

// SerialFailedAttempts returns AcctDetails.SerialFailedAttempts.
func (a *DefaultAccount) SerialFailedAttempts() int {
	return a.AcctDetails.SerialFailedAttempts
}

// PwdChangedTime returns AcctDetails.PwdChangedTime.
func (a *DefaultAccount) PwdChangedTime() time.Time {
	return a.AcctDetails.PwdChangedTime
}

// GracefulAuthCount returns AcctDetails.GracefulAuthCount.
func (a *DefaultAccount) GracefulAuthCount() int {
	return a.AcctDetails.GracefulAuthCount
}
/***********************************
	security.AccountUpdater
 ***********************************/

// IncrementGracefulAuthCount adds one to AcctDetails.GracefulAuthCount.
func (a *DefaultAccount) IncrementGracefulAuthCount() {
	a.AcctDetails.GracefulAuthCount++
}

// LockAccount marks the account as locked. LockoutTime is stamped with time.Now() only
// on the unlocked -> locked transition, so repeated calls keep the original lockout time.
func (a *DefaultAccount) LockAccount() {
	details := &a.AcctDetails
	if !details.Locked {
		details.LockoutTime = time.Now()
	}
	details.Locked = true
}

// UnlockAccount clears the locked flag.
func (a *DefaultAccount) UnlockAccount() {
	// LockoutTime is intentionally kept for record keeping purposes.
	a.AcctDetails.Locked = false
}
// RecordFailure appends failureTime to the login-failure history and keeps only the most
// recent limit entries. It then sets SerialFailedAttempts to the number of entries kept.
//
// A limit of 0 keeps no history, and a negative limit is treated as 0 rather than causing
// a slice-bounds panic.
//
// A fresh slice is always allocated. This ensures the stored history never shares its
// backing array with a slice obtained from LoginFailures(), with a CacheableCopy, or with
// the AcctDetails originally passed to NewUsernamePasswordAccount.
func (a *DefaultAccount) RecordFailure(failureTime time.Time, limit int) {
	if limit < 0 {
		limit = 0
	}
	previous := a.AcctDetails.LoginFailures
	failures := make([]time.Time, 0, len(previous)+1)
	failures = append(failures, previous...)
	failures = append(failures, failureTime)
	if len(failures) > limit {
		// keep only the newest `limit` entries
		failures = failures[len(failures)-limit:]
	}
	a.AcctDetails.LoginFailures = failures
	a.AcctDetails.SerialFailedAttempts = len(failures)
}
// RecordSuccess stores loginTime as AcctDetails.LastLoginTime.
func (a *DefaultAccount) RecordSuccess(loginTime time.Time) {
	a.AcctDetails.LastLoginTime = loginTime
}

// ResetFailedAttempts clears the failure history. LoginFailures becomes an empty,
// non-nil slice and SerialFailedAttempts becomes 0.
func (a *DefaultAccount) ResetFailedAttempts() {
	a.AcctDetails.LoginFailures = make([]time.Time, 0)
	a.AcctDetails.SerialFailedAttempts = 0
}

// ResetGracefulAuthCount sets AcctDetails.GracefulAuthCount back to 0.
func (a *DefaultAccount) ResetGracefulAuthCount() {
	a.AcctDetails.GracefulAuthCount = 0
}
/***********************************
	security.AccountLockingRule
 ***********************************/

// LockoutPolicyName returns AcctLockingRule.Name.
func (a *DefaultAccount) LockoutPolicyName() string {
	return a.AcctLockingRule.Name
}

// LockoutEnabled returns AcctLockingRule.Enabled.
func (a *DefaultAccount) LockoutEnabled() bool {
	return a.AcctLockingRule.Enabled
}

// LockoutTime returns the time recorded by LockAccount.
// Note that it is stored in AcctDetails, not in AcctLockingRule.
func (a *DefaultAccount) LockoutTime() time.Time {
	return a.AcctDetails.LockoutTime
}

// LockoutDuration returns AcctLockingRule.LockoutDuration.
func (a *DefaultAccount) LockoutDuration() time.Duration {
	return a.AcctLockingRule.LockoutDuration
}

// LockoutFailuresLimit returns AcctLockingRule.FailuresLimit.
func (a *DefaultAccount) LockoutFailuresLimit() int {
	return a.AcctLockingRule.FailuresLimit
}

// LockoutFailuresInterval returns AcctLockingRule.FailuresInterval.
func (a *DefaultAccount) LockoutFailuresInterval() time.Duration {
	return a.AcctLockingRule.FailuresInterval
}
/***********************************
	security.AccountPwdAgingRule
 ***********************************/

// PwdAgingPolicyName returns AcctPasswordPolicy.Name.
func (a *DefaultAccount) PwdAgingPolicyName() string {
	return a.AcctPasswordPolicy.Name
}

// PwdAgingRuleEnforced returns AcctPasswordPolicy.Enabled.
func (a *DefaultAccount) PwdAgingRuleEnforced() bool {
	return a.AcctPasswordPolicy.Enabled
}

// PwdMaxAge returns AcctPasswordPolicy.MaxAge.
func (a *DefaultAccount) PwdMaxAge() time.Duration {
	return a.AcctPasswordPolicy.MaxAge
}

// PwdExpiryWarningPeriod returns AcctPasswordPolicy.ExpiryWarningPeriod.
func (a *DefaultAccount) PwdExpiryWarningPeriod() time.Duration {
	return a.AcctPasswordPolicy.ExpiryWarningPeriod
}

// GracefulAuthLimit returns AcctPasswordPolicy.GracefulAuthLimit.
func (a *DefaultAccount) GracefulAuthLimit() int {
	return a.AcctPasswordPolicy.GracefulAuthLimit
}
/***********************************
	account metadata
 ***********************************/

// RoleNames returns AcctMetadata.RoleNames, or an empty non-nil slice when unset.
func (a *DefaultAccount) RoleNames() []string {
	if names := a.AcctMetadata.RoleNames; names != nil {
		return names
	}
	return []string{}
}

// FirstName returns AcctMetadata.FirstName.
func (a *DefaultAccount) FirstName() string {
	return a.AcctMetadata.FirstName
}

// LastName returns AcctMetadata.LastName.
func (a *DefaultAccount) LastName() string {
	return a.AcctMetadata.LastName
}

// Email returns AcctMetadata.Email.
func (a *DefaultAccount) Email() string {
	return a.AcctMetadata.Email
}

// LocaleCode returns AcctMetadata.LocaleCode.
func (a *DefaultAccount) LocaleCode() string {
	return a.AcctMetadata.LocaleCode
}

// CurrencyCode returns AcctMetadata.CurrencyCode.
func (a *DefaultAccount) CurrencyCode() string {
	return a.AcctMetadata.CurrencyCode
}

// Value returns the AcctMetadata.Extra entry for key, or nil if there is none.
func (a *DefaultAccount) Value(key string) interface{} {
	// Indexing a nil map is safe in Go and yields the zero value (nil) for missing keys.
	return a.AcctMetadata.Extra[key]
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import "context"
// AnonymousCandidate is the Candidate accepted by AnonymousAuthenticator.
// Its entries are exposed verbatim as the candidate's details.
type AnonymousCandidate map[string]interface{}

// Principal implements security.Candidate. It always returns the string "anonymous".
func (c AnonymousCandidate) Principal() interface{} {
	const anonymousPrincipal = "anonymous"
	return anonymousPrincipal
}

// Credentials implements security.Candidate. It always returns the empty string.
func (AnonymousCandidate) Credentials() interface{} {
	return ""
}

// Details implements security.Candidate. It returns the candidate map itself.
func (c AnonymousCandidate) Details() interface{} {
	return c
}
// AnonymousAuthentication is the Authentication produced for an AnonymousCandidate.
// Principal and details are delegated to the candidate; it carries no permissions.
type AnonymousAuthentication struct {
	candidate AnonymousCandidate
}

// Principal returns the wrapped candidate's principal.
func (aa *AnonymousAuthentication) Principal() interface{} {
	return aa.candidate.Principal()
}

// Permissions always returns a new, empty permission set.
func (*AnonymousAuthentication) Permissions() Permissions {
	return map[string]interface{}{}
}

// State always returns StateAnonymous.
func (*AnonymousAuthentication) State() AuthenticationState {
	return StateAnonymous
}

// Details returns the wrapped candidate's details.
func (aa *AnonymousAuthentication) Details() interface{} {
	return aa.candidate.Details()
}
// AnonymousAuthenticator authenticates AnonymousCandidate values.
type AnonymousAuthenticator struct{}

// Authenticate wraps an AnonymousCandidate in an AnonymousAuthentication.
// Any other candidate type yields (nil, nil), meaning "not supported".
func (a *AnonymousAuthenticator) Authenticate(_ context.Context, candidate Candidate) (auth Authentication, err error) {
	anon, ok := candidate.(AnonymousCandidate)
	if !ok {
		// not ours: leave it to other authenticators
		return nil, nil
	}
	return &AnonymousAuthentication{candidate: anon}, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"net/http"
"sort"
"sync"
)
/*****************************
Abstraction
*****************************/
// Candidate is the input to an Authenticator: who is authenticating (Principal), the
// proof they present (Credentials), and any additional information (Details).
type Candidate interface {
Principal() interface{}
Credentials() interface{}
Details() interface{}
}
// Authenticator verifies a Candidate and produces an Authentication.
type Authenticator interface {
// Authenticate takes the Candidate and authenticates it.
// If the Candidate type is not supported, it returns (nil, nil).
// If the Candidate is rejected, it returns a non-nil error and the returned Authentication is ignored.
Authenticate(context.Context, Candidate) (Authentication, error)
}
// AuthenticatorBuilder creates an Authenticator.
type AuthenticatorBuilder interface {
Build(context.Context) (Authenticator, error)
}
// AuthenticationSuccessHandler handles authentication success events.
// "from" is the authentication in effect before and "to" is the newly established one.
// The counterpart of this interface is AuthenticationErrorHandler.
type AuthenticationSuccessHandler interface {
HandleAuthenticationSuccess(c context.Context, r *http.Request, rw http.ResponseWriter, from, to Authentication)
}
/*****************************
Common Impl.
*****************************/
// CompositeAuthenticator implements Authenticator by delegating to an ordered list of
// Authenticators. The first delegate that returns a non-nil Authentication or error wins.
type CompositeAuthenticator struct {
	// init guards the one-time computation of flattened.
	init sync.Once
	// authenticators are the direct delegates, kept stable-sorted by order.
	authenticators []Authenticator
	// flattened is authenticators with nested composites expanded; computed on first Authenticate.
	flattened []Authenticator
}

// NewAuthenticator returns a CompositeAuthenticator over the given authenticators.
func NewAuthenticator(authenticators ...Authenticator) Authenticator {
	composite := new(CompositeAuthenticator)
	composite.authenticators = composite.processAuthenticators(authenticators)
	return composite
}
// Authenticate tries each flattened delegate in order. It returns the first non-nil
// Authentication or error. If no delegate handles the candidate, it returns the error
// created by NewAuthenticatorNotAvailableError.
//
// The flattened delegate list is computed once, on the first call. Authenticators added
// or merged afterwards are not consulted.
func (a *CompositeAuthenticator) Authenticate(ctx context.Context, candidate Candidate) (auth Authentication, err error) {
	a.init.Do(func() { a.flattened = a.Authenticators() })
	for _, authenticator := range a.flattened {
		auth, err = authenticator.Authenticate(ctx, candidate)
		if auth != nil || err != nil {
			return auth, err
		}
	}
	return nil, NewAuthenticatorNotAvailableError(fmt.Sprintf("unable to find authenticator for candidate %T", candidate))
}
// Authenticators returns the delegate list with every nested CompositeAuthenticator
// expanded recursively, stable-sorted by order.
func (a *CompositeAuthenticator) Authenticators() []Authenticator {
	result := make([]Authenticator, 0, len(a.authenticators))
	for _, delegate := range a.authenticators {
		if nested, ok := delegate.(*CompositeAuthenticator); ok {
			result = append(result, nested.Authenticators()...)
			continue
		}
		result = append(result, delegate)
	}
	sort.SliceStable(result, func(i, j int) bool {
		return order.OrderedFirstCompare(result[i], result[j])
	})
	return result
}
// Add appends authenticator to the composite's delegates and returns the composite for
// chaining. References to the composite itself are dropped, and the delegates stay
// stable-sorted by order.
//
// Authenticate computes its delegate list only once, on first use, so delegates added
// after that point are not consulted.
func (a *CompositeAuthenticator) Add(authenticator Authenticator) *CompositeAuthenticator {
	// processAuthenticators already stable-sorts the result; no second sort is needed.
	a.authenticators = a.processAuthenticators(append(a.authenticators, authenticator))
	return a
}
// Merge appends the direct delegates of composite to this composite and returns it for
// chaining. The result stays stable-sorted by order.
//
// Merging nil or the composite itself is a no-op. The first would dereference a nil
// pointer, and the second would duplicate every delegate.
func (a *CompositeAuthenticator) Merge(composite *CompositeAuthenticator) *CompositeAuthenticator {
	if composite == nil || composite == a {
		return a
	}
	a.authenticators = a.processAuthenticators(append(a.authenticators, composite.authenticators...))
	return a
}
// processAuthenticators removes the receiver from authenticators (preventing
// self-recursion) and stable-sorts the remainder by order. The input slice is reused.
func (a *CompositeAuthenticator) processAuthenticators(authenticators []Authenticator) []Authenticator {
	filtered := a.removeSelf(authenticators)
	sort.SliceStable(filtered, func(i, j int) bool {
		return order.OrderedFirstCompare(filtered[i], filtered[j])
	})
	return filtered
}

// removeSelf filters the receiver out of authenticators in place and returns the
// shortened slice.
func (a *CompositeAuthenticator) removeSelf(authenticators []Authenticator) []Authenticator {
	kept := authenticators[:0]
	for _, item := range authenticators {
		if self, isComposite := item.(*CompositeAuthenticator); isComposite && self == a {
			continue
		}
		kept = append(kept, item)
	}
	// nil out the dropped tail so those references can be garbage collected
	for i := len(kept); i < len(authenticators); i++ {
		authenticators[i] = nil
	}
	return kept
}
// CompositeAuthenticationSuccessHandler implements AuthenticationSuccessHandler by
// invoking every delegate handler in order.
type CompositeAuthenticationSuccessHandler struct {
	// init guards the one-time computation of flattened.
	init sync.Once
	// handlers are the direct delegates, kept stable-sorted by order.
	handlers []AuthenticationSuccessHandler
	// flattened is handlers with nested composites expanded; computed on first use.
	flattened []AuthenticationSuccessHandler
}

// NewAuthenticationSuccessHandler returns a composite over the given handlers.
func NewAuthenticationSuccessHandler(handlers ...AuthenticationSuccessHandler) *CompositeAuthenticationSuccessHandler {
	composite := new(CompositeAuthenticationSuccessHandler)
	composite.handlers = composite.processSuccessHandlers(handlers)
	return composite
}
// HandleAuthenticationSuccess invokes every flattened delegate, in order, with the same
// arguments. The flattened list is computed once, on the first call.
func (h *CompositeAuthenticationSuccessHandler) HandleAuthenticationSuccess(
	c context.Context, r *http.Request, rw http.ResponseWriter, from, to Authentication) {
	h.init.Do(func() {
		h.flattened = h.Handlers()
	})
	delegates := h.flattened
	for _, delegate := range delegates {
		delegate.HandleAuthenticationSuccess(c, r, rw, from, to)
	}
}
// Handlers returns the delegate handlers with every nested
// CompositeAuthenticationSuccessHandler expanded recursively, stable-sorted by order.
func (h *CompositeAuthenticationSuccessHandler) Handlers() []AuthenticationSuccessHandler {
	result := make([]AuthenticationSuccessHandler, 0, len(h.handlers))
	for _, delegate := range h.handlers {
		if nested, ok := delegate.(*CompositeAuthenticationSuccessHandler); ok {
			result = append(result, nested.Handlers()...)
			continue
		}
		result = append(result, delegate)
	}
	sort.SliceStable(result, func(i, j int) bool {
		return order.OrderedFirstCompare(result[i], result[j])
	})
	return result
}
// Add appends handler to the delegates (dropping any self-reference, keeping order-sorted)
// and returns the composite for chaining.
func (h *CompositeAuthenticationSuccessHandler) Add(handler AuthenticationSuccessHandler) *CompositeAuthenticationSuccessHandler {
	combined := append(h.handlers, handler)
	h.handlers = h.processSuccessHandlers(combined)
	return h
}
// Merge appends the direct delegates of composite to this composite and returns it for
// chaining. The result stays stable-sorted by order.
//
// Merging nil or the composite itself is a no-op. The first would dereference a nil
// pointer, and the second would duplicate every delegate.
func (h *CompositeAuthenticationSuccessHandler) Merge(composite *CompositeAuthenticationSuccessHandler) *CompositeAuthenticationSuccessHandler {
	if composite == nil || composite == h {
		return h
	}
	h.handlers = h.processSuccessHandlers(append(h.handlers, composite.handlers...))
	return h
}
// processSuccessHandlers removes the receiver from handlers (preventing self-recursion)
// and stable-sorts the remainder by order. The input slice is reused.
func (h *CompositeAuthenticationSuccessHandler) processSuccessHandlers(handlers []AuthenticationSuccessHandler) []AuthenticationSuccessHandler {
	filtered := h.removeSelf(handlers)
	sort.SliceStable(filtered, func(i, j int) bool {
		return order.OrderedFirstCompare(filtered[i], filtered[j])
	})
	return filtered
}

// removeSelf filters the receiver out of items in place and returns the shortened slice.
func (h *CompositeAuthenticationSuccessHandler) removeSelf(items []AuthenticationSuccessHandler) []AuthenticationSuccessHandler {
	kept := items[:0]
	for _, item := range items {
		if self, isComposite := item.(*CompositeAuthenticationSuccessHandler); isComposite && self == h {
			continue
		}
		kept = append(kept, item)
	}
	// nil out the dropped tail so those references can be garbage collected
	for i := len(kept); i < len(items); i++ {
		items[i] = nil
	}
	return kept
}
// CompositeAuthenticatorBuilder implements AuthenticatorBuilder. It builds each
// registered AuthenticatorBuilder and combines the results with NewAuthenticator.
type CompositeAuthenticatorBuilder struct {
	builders []AuthenticatorBuilder
}

// NewAuthenticatorBuilder returns a CompositeAuthenticatorBuilder with no builders.
func NewAuthenticatorBuilder() *CompositeAuthenticatorBuilder {
	return &CompositeAuthenticatorBuilder{
		builders: make([]AuthenticatorBuilder, 0),
	}
}

// Build builds every registered builder, in registration order, and returns their
// composite. The first build error aborts and is returned as-is.
func (b *CompositeAuthenticatorBuilder) Build(c context.Context) (Authenticator, error) {
	built := make([]Authenticator, 0, len(b.builders))
	for _, builder := range b.builders {
		authenticator, err := builder.Build(c)
		if err != nil {
			return nil, err
		}
		built = append(built, authenticator)
	}
	return NewAuthenticator(built...), nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package basicauth
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
)
var (
// FeatureId identifies the basic auth feature, ordered by security.FeatureOrderBasicAuth.
FeatureId = security.FeatureId("BasicAuth", security.FeatureOrderBasicAuth)
)
// BasicAuthFeature enables HTTP Basic authentication on a security.WebSecurity.
// Its only configurable option is the authentication entry point.
//
//goland:noinspection GoNameStartsWithPackageName
type BasicAuthFeature struct {
	entryPoint security.AuthenticationEntryPoint
}

// Identifier implements security.Feature and returns FeatureId.
func (f *BasicAuthFeature) Identifier() security.FeatureIdentifier {
	return FeatureId
}
// Configure enables the basic auth feature on ws and returns it for further DSL-style
// configuration. It panics if ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *BasicAuthFeature {
	feature := New()
	if fc, ok := ws.(security.FeatureModifier); ok {
		return fc.Enable(feature).(*BasicAuthFeature)
	}
	// the message previously said "configure session" (copy-paste error); this is basic auth
	panic(fmt.Errorf("unable to configure basic auth: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
}
// New is the standard security.Feature entrypoint, DSL style, used with security.WebSecurity.
// The returned feature uses NewBasicAuthEntryPoint() as its entry point.
func New() *BasicAuthFeature {
	feature := &BasicAuthFeature{}
	feature.entryPoint = NewBasicAuthEntryPoint()
	return feature
}

// EntryPoint replaces the authentication entry point and returns the feature for chaining.
// A nil entry point means none is installed: BasicAuthConfigurer.Apply skips it.
func (f *BasicAuthFeature) EntryPoint(entrypoint security.AuthenticationEntryPoint) *BasicAuthFeature {
	f.entryPoint = entrypoint
	return f
}
// BasicAuthConfigurer applies a BasicAuthFeature to a security.WebSecurity. It is stateless.
//
//goland:noinspection GoNameStartsWithPackageName
type BasicAuthConfigurer struct{}

// newBasicAuthConfigurer returns a new BasicAuthConfigurer.
func newBasicAuthConfigurer() *BasicAuthConfigurer {
	return new(BasicAuthConfigurer)
}
// Apply wires basic auth into ws. It:
//   - adds a BasicAuthErrorHandler to the shared composite authentication error handler;
//   - installs the feature's entry point via errorhandling, unless it was set to nil;
//   - adds the basic auth middleware at security.MWOrderBasicAuth.
//
// The shared objects and f are type-asserted without checks, so wrong types panic.
func (bac *BasicAuthConfigurer) Apply(f security.Feature, ws security.WebSecurity) error {
	sharedErrHandler := ws.Shared(security.WSSharedKeyCompositeAuthErrorHandler).(*security.CompositeAuthenticationErrorHandler)
	sharedErrHandler.Add(NewBasicAuthErrorHandler())

	// NewBasicAuthEntryPoint() by default; other configurers may replace or unset it
	feature := f.(*BasicAuthFeature)
	if ep := feature.entryPoint; ep != nil {
		errorhandling.Configure(ws).AuthenticationEntryPoint(ep)
	}

	authenticator := ws.Authenticator()
	successHandler := ws.Shared(security.WSSharedKeyCompositeAuthSuccessHandler).(security.AuthenticationSuccessHandler)
	mw := NewBasicAuthMiddleware(authenticator, successHandler)

	ws.Add(middleware.NewBuilder("basic auth").
		Order(security.MWOrderBasicAuth).
		Use(mw.HandlerFunc()))
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package basicauth
import (
"context"
"encoding/base64"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
"github.com/gin-gonic/gin"
"net/http"
"strconv"
"strings"
)
// BasicAuthMiddleware authenticates requests carrying an HTTP Basic "Authorization" header.
//
//goland:noinspection GoNameStartsWithPackageName
type BasicAuthMiddleware struct {
	authenticator  security.Authenticator
	successHandler security.AuthenticationSuccessHandler
}

// NewBasicAuthMiddleware creates the middleware. authenticator verifies the
// username/password pair, and successHandler is notified after a new authentication is set.
func NewBasicAuthMiddleware(authenticator security.Authenticator, successHandler security.AuthenticationSuccessHandler) *BasicAuthMiddleware {
	mw := &BasicAuthMiddleware{}
	mw.authenticator = authenticator
	mw.successHandler = successHandler
	return mw
}
// HandlerFunc returns the gin middleware that performs HTTP Basic authentication:
//   - With no "Authorization" header, or a header using another scheme, the request passes
//     through untouched. The scheme name is matched case-insensitively, as required by
//     RFC 7235 §2.1 and RFC 7617.
//   - Credentials that are not valid base64, or that lack the ':' separating username and
//     password, clear the security context and abort with a BadCredentials error.
//   - If the current authentication is a passwd.UsernamePasswordAuthentication for which
//     passwd.IsSamePrincipal reports the same username, it is kept and the chain continues.
//   - Otherwise the pair is passed to the configured authenticator. An error aborts the
//     request; on success the new authentication is stored and the success handler runs.
func (basic *BasicAuthMiddleware) HandlerFunc() gin.HandlerFunc {
	const scheme = "Basic "
	return func(ctx *gin.Context) {
		header := ctx.GetHeader("Authorization")
		if len(header) < len(scheme) || !strings.EqualFold(header[:len(scheme)], scheme) {
			// Authorization header not available or not basic auth, bail
			return
		}
		// tolerate extra whitespace around the base64 token
		encoded := strings.TrimSpace(header[len(scheme):])
		decoded, err := base64.StdEncoding.DecodeString(encoded)
		if err != nil {
			basic.handleError(ctx, security.NewBadCredentialsError("invalid Authorization header"))
			return
		}
		// the password may itself contain ':', so split only on the first one
		pair := strings.SplitN(string(decoded), ":", 2)
		if len(pair) < 2 {
			basic.handleError(ctx, security.NewBadCredentialsError("invalid Authorization header"))
			return
		}
		username, password := pair[0], pair[1]

		before := security.Get(ctx)
		if current, ok := before.(passwd.UsernamePasswordAuthentication); ok && passwd.IsSamePrincipal(username, current) {
			// already authenticated as this user
			basic.handleSuccess(ctx, before, nil)
			return
		}

		candidate := passwd.UsernamePasswordPair{
			Username: username,
			Password: password,
		}
		auth, err := basic.authenticator.Authenticate(ctx, &candidate)
		if err != nil {
			basic.handleError(ctx, err)
			return
		}
		basic.handleSuccess(ctx, before, auth)
	}
}
// handleSuccess stores after in the security context and notifies the success handler,
// when after is non-nil. It then continues the gin handler chain.
func (basic *BasicAuthMiddleware) handleSuccess(c *gin.Context, before, after security.Authentication) {
	if after != nil {
		security.MustSet(c, after)
		basic.successHandler.HandleAuthenticationSuccess(c, c.Request, c.Writer, before, after)
	}
	c.Next()
}

// handleError clears the security context, records err on the gin context and aborts
// the remaining handlers.
func (basic *BasicAuthMiddleware) handleError(c *gin.Context, err error) {
	security.MustClear(c)
	// gin.Context.Error returns the wrapped *gin.Error, which is not needed here
	_ = c.Error(err)
	c.Abort()
}
// BasicAuthEntryPoint is an authentication entry point. It adds the Basic
// "WWW-Authenticate" challenge and then delegates to the embedded
// DefaultAuthenticationErrorHandler.
//
//goland:noinspection GoNameStartsWithPackageName
type BasicAuthEntryPoint struct {
	security.DefaultAuthenticationErrorHandler
}

// NewBasicAuthEntryPoint returns a zero-valued BasicAuthEntryPoint.
func NewBasicAuthEntryPoint() *BasicAuthEntryPoint {
	return new(BasicAuthEntryPoint)
}

// Commence writes the Basic challenge header, then lets the default error handler
// produce the response.
func (h *BasicAuthEntryPoint) Commence(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	writeBasicAuthChallenge(rw, err)
	h.DefaultAuthenticationErrorHandler.HandleAuthenticationError(c, r, rw, err)
}
// BasicAuthErrorHandler is an authentication error handler that only adds the Basic
// "WWW-Authenticate" challenge header. Writing the response body is left to other handlers.
//
//goland:noinspection GoNameStartsWithPackageName
type BasicAuthErrorHandler struct{}

// NewBasicAuthErrorHandler returns a new BasicAuthErrorHandler.
func NewBasicAuthErrorHandler() *BasicAuthErrorHandler {
	return new(BasicAuthErrorHandler)
}

// HandleAuthenticationError sets the Basic challenge header on rw.
func (h *BasicAuthErrorHandler) HandleAuthenticationError(_ context.Context, _ *http.Request, rw http.ResponseWriter, err error) {
	writeBasicAuthChallenge(rw, err)
}
// writeBasicAuthChallenge sets the "WWW-Authenticate" header to a Basic challenge with
// realm "Authorization Required", replacing any previous value. The error is not used.
func writeBasicAuthChallenge(rw http.ResponseWriter, _ error) {
	challenge := "Basic realm=" + strconv.Quote("Authorization Required")
	rw.Header().Set("WWW-Authenticate", challenge)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package basicauth
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"go.uber.org/fx"
)
// Module is the bootstrap module of the basic auth feature. It invokes register with
// precedence security.MinSecurityPrecedence + 20.
//
//goland:noinspection GoNameStartsWithPackageName
var Module = &bootstrap.Module{
	Name:       "basic auth",
	Precedence: security.MinSecurityPrecedence + 20,
	Options: []fx.Option{
		fx.Invoke(register),
	},
}

// init registers Module with bootstrap when the package is imported.
func init() {
	bootstrap.Register(Module)
}
// initDI carries register's dependencies; the security registrar is optional.
type initDI struct {
	fx.In
	SecRegistrar security.Registrar `optional:"true"`
}

// register adds the basic auth feature configurer to the security registrar, if one is
// available. The registrar must also implement security.FeatureRegistrar; otherwise the
// unchecked type assertion panics.
func register(di initDI) {
	if di.SecRegistrar == nil {
		// no security registrar was provided: nothing to register
		return
	}
	configurer := newBasicAuthConfigurer()
	featureRegistrar := di.SecRegistrar.(security.FeatureRegistrar)
	featureRegistrar.RegisterFeature(FeatureId, configurer)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package authserver
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/config/compatibility"
"github.com/cisco-open/go-lanai/pkg/security/idp"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/grants"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/openid"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/revoke"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/common"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/tokenauth"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"go.uber.org/fx"
"net/url"
)
// Relative order values for the security configurers of the authorization server.
// NOTE(review): presumably consumed by the configurers' ordering (not visible here) — confirm.
const (
OrderAuthorizeSecurityConfigurer = 0
OrderLogoutSecurityConfigurer = 50
OrderClientAuthSecurityConfigurer = 100
OrderTokenAuthSecurityConfigurer = 200
)
// AuthorizationServerConfigurer is the application-provided hook that ProvideAuthServerDI
// invokes on the Configuration before exporting it.
type AuthorizationServerConfigurer func(*Configuration)
// configDI carries the fx-injected dependencies of ProvideAuthServerDI.
// TimeoutSupport and ApprovalStore are optional and may be nil.
type configDI struct {
fx.In
AppContext *bootstrap.ApplicationContext
Properties AuthServerProperties
Configurer AuthorizationServerConfigurer
RedisClientFactory redis.ClientFactory
ServerProperties web.ServerProperties
SessionProperties security.SessionProperties
CryptoProperties jwt.CryptoProperties
SessionStore session.Store
TimeoutSupport oauth2.TimeoutApplier `optional:"true"`
ApprovalStore auth.ApprovalStore `optional:"true"`
}
// authServerOut is the fx output of ProvideAuthServerDI: the Configuration and a
// service-registration customizer contributed to the "discovery" group.
type authServerOut struct {
fx.Out
Config *Configuration
CompatibilityCustomizer discovery.ServiceRegistrationCustomizer `group:"discovery"`
}
// ProvideAuthServerDI assembles the authorization server Configuration from injected
// properties and services. It lets the application's AuthorizationServerConfigurer
// customize the Configuration, then exports it with the compatibility discovery customizer.
//
// The SAML SSO endpoint shares the authorize path. The two endpoints are told apart by
// whether the grant-type form parameter equals samlctx.GrantTypeSamlSSO.
//
//goland:noinspection GoExportedFuncWithUnexportedType
func ProvideAuthServerDI(di configDI) authServerOut {
	ep := di.Properties.Endpoints
	samlSsoQuery := fmt.Sprintf("%s=%s", oauth2.ParameterGrantType, samlctx.GrantTypeSamlSSO)

	config := Configuration{
		appContext:         di.AppContext,
		redisClientFactory: di.RedisClientFactory,
		sessionStore:       di.SessionStore,
		properties:         di.Properties,
		serverProperties:   di.ServerProperties,
		sessionProperties:  di.SessionProperties,
		cryptoProperties:   di.CryptoProperties,
		Issuer:             newIssuer(&di.Properties.Issuer, &di.ServerProperties),
		timeoutSupport:     di.TimeoutSupport,
		ApprovalStore:      di.ApprovalStore,
		Endpoints: Endpoints{
			Authorize: ConditionalEndpoint{
				Location:  &url.URL{Path: ep.Authorize},
				Condition: matcher.NotRequest(matcher.RequestWithForm(oauth2.ParameterGrantType, samlctx.GrantTypeSamlSSO)),
			},
			Approval:   ep.Approval,
			Token:      ep.Token,
			CheckToken: ep.CheckToken,
			UserInfo:   ep.UserInfo,
			JwkSet:     ep.JwkSet,
			Error:      ep.Error,
			Logout:     ep.Logout,
			LoggedOut:  ep.LoggedOut,
			SamlSso: ConditionalEndpoint{
				Location:  &url.URL{Path: ep.Authorize, RawQuery: samlSsoQuery},
				Condition: matcher.RequestWithForm(oauth2.ParameterGrantType, samlctx.GrantTypeSamlSSO),
			},
			SamlMetadata:    ep.SamlMetadata,
			TenantHierarchy: ep.TenantHierarchy,
		},
		OpenIDSSOEnabled: true,
	}

	// application-level customization happens before the configuration is exported
	di.Configurer(&config)

	out := authServerOut{Config: &config}
	out.CompatibilityCustomizer = compatibility.CompatibilityDiscoveryCustomizer{}
	return out
}
// initDI carries the dependencies of ConfigureAuthorizationServer.
type initDI struct {
	fx.In
	Config            *Configuration
	WebRegistrar      *web.Registrar
	SecurityRegistrar security.Registrar
}

// ConfigureAuthorizationServer is the Configuration entry point. It registers the
// security configurers for the client-auth and token-auth endpoints, one authorize
// configurer per IdP configurer, and the logout configurer. It then enables duplicate-
// middleware warnings for the interactive endpoints and registers the web endpoints.
func ConfigureAuthorizationServer(di initDI) {
	cfg := di.Config
	secReg := di.SecurityRegistrar

	// Securities
	secReg.Register(&ClientAuthEndpointsConfigurer{config: cfg})
	secReg.Register(&TokenAuthEndpointsConfigurer{config: cfg})
	for _, idpConfigurer := range cfg.idpConfigurers {
		secReg.Register(&AuthorizeEndpointConfigurer{config: cfg, delegate: idpConfigurer})
	}
	secReg.Register(&LogoutEndpointConfigurer{config: cfg, delegates: cfg.idpConfigurers})

	// Additional endpoints and other web configurations
	webReg := di.WebRegistrar
	webReg.WarnDuplicateMiddlewares(true,
		cfg.Endpoints.Authorize.Location.Path,
		cfg.Endpoints.SamlSso.Location.Path,
		cfg.Endpoints.Approval,
		cfg.Endpoints.Logout,
	)
	registerEndpoints(webReg, cfg)
}
/****************************
configuration
****************************/
// ConditionalEndpoint is an endpoint location that applies only to requests matching Condition.
type ConditionalEndpoint struct {
Location *url.URL
Condition web.RequestMatcher
}
// Endpoints lists the authorization server's endpoint paths. In ProvideAuthServerDI,
// Authorize and SamlSso share the same path and are distinguished by their Conditions.
type Endpoints struct {
Authorize ConditionalEndpoint
Approval string
Token string
CheckToken string
UserInfo string
JwkSet string
Logout string
LoggedOut string
Error string
SamlSso ConditionalEndpoint
SamlMetadata string
TenantHierarchy string
}
// Configuration holds everything needed to set up the authorization server. Exported
// fields may be customized by an AuthorizationServerConfigurer. Several fields fall back
// to lazily created defaults through the unexported accessor methods. The unexported
// "shared*" fields cache those lazily created components.
type Configuration struct {
// configurable items
SessionSettingService session.SettingService
ClientStore oauth2.OAuth2ClientStore
// ClientSecretEncoder defaults to passwd.NewNoopPasswordEncoder() when nil (see clientSecretEncoder).
ClientSecretEncoder passwd.PasswordEncoder
Endpoints Endpoints
// UserAccountStore is also needed for the password grant: when nil, no password granter is registered.
UserAccountStore security.AccountStore
TenantStore security.TenantStore
ProviderStore security.ProviderStore
// UserPasswordEncoder defaults to passwd.NewNoopPasswordEncoder() when nil (see userPasswordEncoder).
UserPasswordEncoder passwd.PasswordEncoder
// TokenStore defaults to auth.NewJwtTokenStore(...) when nil (see tokenStore).
TokenStore auth.TokenStore
// JwkStore defaults to jwt.NewFileJwkStore(cryptoProperties) when nil (see jwkStore).
JwkStore jwt.JwkStore
IdpManager idp.IdentityProviderManager
Issuer security.Issuer
// OpenIDSSOEnabled adds an OpenID token enhancer as a post-token enhancer; ProvideAuthServerDI sets it to true.
OpenIDSSOEnabled bool
SamlIdpSigningMethod string
ApprovalStore auth.ApprovalStore
// CustomTokenGranter entries are appended after the built-in granters; those implementing
// auth.AuthorizationServiceInjector receive the authorization service first.
CustomTokenGranter []auth.TokenGranter
// CustomTokenEnhancer entries are appended to the authorization service's TokenEnhancers.
CustomTokenEnhancer []auth.TokenEnhancer
// CustomAuthRegistry, when set, replaces the context details store as the authorization registry.
CustomAuthRegistry auth.AuthorizationRegistry
// not directly configurable items
appContext *bootstrap.ApplicationContext
redisClientFactory redis.ClientFactory
sessionStore session.Store
properties AuthServerProperties
serverProperties web.ServerProperties
sessionProperties security.SessionProperties
cryptoProperties jwt.CryptoProperties
// idpConfigurers are added via AddIdp.
idpConfigurers []IdpSecurityConfigurer
sharedContextDetailsStore security.ContextDetailsStore
sharedAuthRegistry auth.AuthorizationRegistry
sharedAccessRevoker auth.AccessRevoker
sharedErrorHandler *auth.OAuth2ErrorHandler
sharedTokenGranter auth.TokenGranter
sharedAuthService auth.AuthorizationService
sharedPasswdAuthenticator security.Authenticator
sharedJwtEncoder jwt.JwtEncoder
sharedJwtDecoder jwt.JwtDecoder
sharedDetailsFactory *common.ContextDetailsFactory
sharedARProcessor auth.AuthorizeRequestProcessor
sharedAuthHandler auth.AuthorizeHandler
sharedAuthCodeStore auth.AuthorizationCodeStore
sharedTokenAuthenticator security.Authenticator
timeoutSupport oauth2.TimeoutApplier
}
// AddIdp registers an identity-provider security configurer. ConfigureAuthorizationServer
// creates one authorize-endpoint configurer per registered IdP and passes all of them to
// the logout configurer.
func (c *Configuration) AddIdp(configurer IdpSecurityConfigurer) {
	updated := append(c.idpConfigurers, configurer)
	c.idpConfigurers = updated
}
// newIssuer builds the security.Issuer from the issuer properties. If the issuer does not
// define its own context path, the web server's context path is used.
func newIssuer(props *IssuerProperties, serverProps *web.ServerProperties) security.Issuer {
	contextPath := serverProps.ContextPath
	if props.ContextPath != "" {
		contextPath = props.ContextPath
	}
	return security.NewIssuer(func(opt *security.DefaultIssuerDetails) {
		*opt = security.DefaultIssuerDetails{
			Protocol:    props.Protocol,
			Domain:      props.Domain,
			Port:        props.Port,
			ContextPath: contextPath,
			IncludePort: props.IncludePort,
		}
	})
}
// clientSecretEncoder returns ClientSecretEncoder. If unset, it defaults to (and stores)
// passwd.NewNoopPasswordEncoder().
func (c *Configuration) clientSecretEncoder() passwd.PasswordEncoder {
	if c.ClientSecretEncoder != nil {
		return c.ClientSecretEncoder
	}
	c.ClientSecretEncoder = passwd.NewNoopPasswordEncoder()
	return c.ClientSecretEncoder
}

// userPasswordEncoder returns UserPasswordEncoder. If unset, it defaults to (and stores)
// passwd.NewNoopPasswordEncoder().
func (c *Configuration) userPasswordEncoder() passwd.PasswordEncoder {
	if c.UserPasswordEncoder != nil {
		return c.UserPasswordEncoder
	}
	c.UserPasswordEncoder = passwd.NewNoopPasswordEncoder()
	return c.UserPasswordEncoder
}

// errorHandler lazily creates and caches the shared OAuth2 error handler.
func (c *Configuration) errorHandler() *auth.OAuth2ErrorHandler {
	if c.sharedErrorHandler != nil {
		return c.sharedErrorHandler
	}
	c.sharedErrorHandler = auth.NewOAuth2ErrorHandler()
	return c.sharedErrorHandler
}
// tokenGranter lazily builds and caches the composite token granter. It contains:
//   - the built-in granters: authorization code, client credentials, refresh, switch user,
//     switch tenant;
//   - the password granter, only when passwordGrantAuthenticator() is available;
//   - CustomTokenGranter entries, appended last. Those implementing
//     auth.AuthorizationServiceInjector are injected with the authorization service.
func (c *Configuration) tokenGranter() auth.TokenGranter {
	if c.sharedTokenGranter != nil {
		return c.sharedTokenGranter
	}
	authService := c.authorizationService()
	granters := []auth.TokenGranter{
		grants.NewAuthorizationCodeGranter(authService, c.authorizeCodeStore()),
		grants.NewClientCredentialsGranter(c.authorizationService()),
		grants.NewRefreshGranter(c.authorizationService(), c.tokenStore()),
		grants.NewSwitchUserGranter(c.authorizationService(), c.tokenAuthenticator(), c.UserAccountStore),
		grants.NewSwitchTenantGranter(c.authorizationService(), c.tokenAuthenticator(), c.UserAccountStore),
	}

	// password granter is optional
	if pwdAuthenticator := c.passwordGrantAuthenticator(); pwdAuthenticator != nil {
		granters = append(granters, grants.NewPasswordGranter(c.authorizationService(), pwdAuthenticator))
	}

	for _, custom := range c.CustomTokenGranter {
		if injector, ok := custom.(auth.AuthorizationServiceInjector); ok {
			injector.Inject(c.authorizationService())
		}
	}
	granters = append(granters, c.CustomTokenGranter...)

	c.sharedTokenGranter = auth.NewCompositeTokenGranter(granters...)
	return c.sharedTokenGranter
}
// passwordGrantAuthenticator lazily builds a username/password authenticator (MFA disabled)
// backed by UserAccountStore and userPasswordEncoder(). It returns nil when
// UserAccountStore is unset or the build fails. Build errors are deliberately ignored
// because the password grant is optional; an unsuccessful build is retried on the next call.
func (c *Configuration) passwordGrantAuthenticator() security.Authenticator {
	if c.sharedPasswdAuthenticator != nil || c.UserAccountStore == nil {
		return c.sharedPasswdAuthenticator
	}
	builder := passwd.NewAuthenticatorBuilder(
		passwd.New().
			AccountStore(c.UserAccountStore).
			PasswordEncoder(c.userPasswordEncoder()).
			MFA(false),
	)
	if authenticator, err := builder.Build(context.Background()); err == nil {
		c.sharedPasswdAuthenticator = authenticator
	}
	return c.sharedPasswdAuthenticator
}
func (c *Configuration) contextDetailsStore() security.ContextDetailsStore {
if c.sharedContextDetailsStore == nil {
c.sharedContextDetailsStore = common.NewRedisContextDetailsStore(c.appContext, c.redisClientFactory, c.timeoutSupport)
}
return c.sharedContextDetailsStore
}
// authorizationRegistry returns the shared authorization registry. It prefers
// CustomAuthRegistry and otherwise uses the context details store, which must also implement
// auth.AuthorizationRegistry. The type assertion panics if it does not.
func (c *Configuration) authorizationRegistry() auth.AuthorizationRegistry {
	if c.sharedAuthRegistry != nil {
		return c.sharedAuthRegistry
	}
	switch {
	case c.CustomAuthRegistry != nil:
		c.sharedAuthRegistry = c.CustomAuthRegistry
	default:
		c.sharedAuthRegistry = c.contextDetailsStore().(auth.AuthorizationRegistry)
	}
	return c.sharedAuthRegistry
}
// tokenStore returns the configured TokenStore. When none is set, it creates a JWT token store
// wired to the shared details store, JWT encoder/decoder and authorization registry. The new
// store is written back to c.TokenStore so it is created only once.
func (c *Configuration) tokenStore() auth.TokenStore {
	if c.TokenStore != nil {
		return c.TokenStore
	}
	c.TokenStore = auth.NewJwtTokenStore(func(o *auth.JTSOption) {
		o.DetailsStore = c.contextDetailsStore()
		o.Encoder = c.jwtEncoder()
		o.Decoder = c.jwtDecoder()
		o.AuthRegistry = c.authorizationRegistry()
	})
	return c.TokenStore
}
// authorizationService lazily builds and caches the default authorization service.
//
// It is wired with the token store, context details factory, issuer and the
// client/account/tenant/provider stores. With OpenIDSSOEnabled, an OpenID token enhancer is
// appended to the post-token enhancers. CustomTokenEnhancer entries are appended to the
// regular token enhancers.
func (c *Configuration) authorizationService() auth.AuthorizationService {
	if c.sharedAuthService != nil {
		return c.sharedAuthService
	}
	c.sharedAuthService = auth.NewDefaultAuthorizationService(func(opt *auth.DASOption) {
		opt.TokenStore = c.tokenStore()
		opt.DetailsFactory = c.contextDetailsFactory()
		opt.Issuer = c.Issuer
		opt.ClientStore = c.ClientStore
		opt.AccountStore = c.UserAccountStore
		opt.TenantStore = c.TenantStore
		opt.ProviderStore = c.ProviderStore
		if c.OpenIDSSOEnabled {
			enhancer := openid.NewOpenIDTokenEnhancer(func(eo *openid.EnhancerOption) {
				eo.Issuer = c.Issuer
				eo.JwtEncoder = c.jwtEncoder()
			})
			opt.PostTokenEnhancers = append(opt.PostTokenEnhancers, enhancer)
		}
		opt.TokenEnhancers = append(opt.TokenEnhancers, c.CustomTokenEnhancer...)
	})
	return c.sharedAuthService
}
// jwkStore returns the configured JwkStore. It defaults to jwt.NewFileJwkStore built from the
// crypto properties, and the default is written back to c.JwkStore.
func (c *Configuration) jwkStore() jwt.JwkStore {
	if c.JwkStore != nil {
		return c.JwkStore
	}
	c.JwkStore = jwt.NewFileJwkStore(c.cryptoProperties)
	return c.JwkStore
}
// jwtEncoder returns the shared signing JWT encoder, creating it on first use. It signs with
// the JWK store key named by cryptoProperties.Jwt.KeyName.
func (c *Configuration) jwtEncoder() jwt.JwtEncoder {
	if c.sharedJwtEncoder == nil {
		signWith := jwt.SignWithJwkStore(c.jwkStore(), c.cryptoProperties.Jwt.KeyName)
		c.sharedJwtEncoder = jwt.NewSignedJwtEncoder(signWith)
	}
	return c.sharedJwtEncoder
}
// jwtDecoder returns the shared verifying JWT decoder, creating it on first use. It verifies
// with the JWK store key named by cryptoProperties.Jwt.KeyName.
func (c *Configuration) jwtDecoder() jwt.JwtDecoder {
	if c.sharedJwtDecoder == nil {
		verifyWith := jwt.VerifyWithJwkStore(c.jwkStore(), c.cryptoProperties.Jwt.KeyName)
		c.sharedJwtDecoder = jwt.NewSignedJwtDecoder(verifyWith)
	}
	return c.sharedJwtDecoder
}
// contextDetailsFactory returns the shared context details factory, creating it on first use.
func (c *Configuration) contextDetailsFactory() *common.ContextDetailsFactory {
	if c.sharedDetailsFactory != nil {
		return c.sharedDetailsFactory
	}
	c.sharedDetailsFactory = common.NewContextDetailsFactory()
	return c.sharedDetailsFactory
}
// authorizeRequestProcessor lazily builds and caches the chained authorize-request processor.
// The standard processor is always present. With OpenIDSSOEnabled, the OpenID processor is
// placed ahead of it in the chain.
func (c *Configuration) authorizeRequestProcessor() auth.AuthorizeRequestProcessor {
	if c.sharedARProcessor != nil {
		return c.sharedARProcessor
	}
	standard := auth.NewStandardAuthorizeRequestProcessor(func(o *auth.StdARPOption) {
		o.ClientStore = c.ClientStore
		o.AccountStore = c.UserAccountStore
	})
	var chain []auth.ChainedAuthorizeRequestProcessor
	if c.OpenIDSSOEnabled {
		chain = append(chain, openid.NewOpenIDAuthorizeRequestProcessor(func(o *openid.ARPOption) {
			o.Issuer = c.Issuer
			o.JwtDecoder = c.jwtDecoder()
		}))
	}
	chain = append(chain, standard)
	c.sharedARProcessor = auth.NewAuthorizeRequestProcessor(chain...)
	return c.sharedARProcessor
}
// authorizeHandler lazily builds and caches the authorize handler. It is configured with the
// "authorize.tmpl" approval page template, the approval URL, the shared authorization service
// and the authorization code store.
func (c *Configuration) authorizeHandler() auth.AuthorizeHandler {
	if c.sharedAuthHandler != nil {
		return c.sharedAuthHandler
	}
	//TODO OIDC Implicit flow extension
	c.sharedAuthHandler = auth.NewAuthorizeHandler(func(o *auth.AuthHandlerOption) {
		//o.Extensions = OIDC extensions
		o.ApprovalPageTmpl = "authorize.tmpl"
		o.ApprovalUrl = c.Endpoints.Approval
		o.AuthService = c.authorizationService()
		o.AuthCodeStore = c.authorizeCodeStore()
	})
	return c.sharedAuthHandler
}
// authorizeCodeStore returns the shared Redis-backed authorization code store, creating it on
// first use. It uses the Redis DB index from the session properties.
func (c *Configuration) authorizeCodeStore() auth.AuthorizationCodeStore {
	if c.sharedAuthCodeStore != nil {
		return c.sharedAuthCodeStore
	}
	c.sharedAuthCodeStore = auth.NewRedisAuthorizationCodeStore(c.appContext, c.redisClientFactory, c.sessionProperties.DbIndex)
	return c.sharedAuthCodeStore
}
// tokenAuthenticator returns the shared token authenticator, creating it on first use. It
// reads tokens through tokenStore().
func (c *Configuration) tokenAuthenticator() security.Authenticator {
	if c.sharedTokenAuthenticator != nil {
		return c.sharedTokenAuthenticator
	}
	c.sharedTokenAuthenticator = tokenauth.NewAuthenticator(func(o *tokenauth.AuthenticatorOption) {
		o.TokenStoreReader = c.tokenStore()
	})
	return c.sharedTokenAuthenticator
}
// accessRevoker returns the shared default access revoker, creating it on first use. It is
// wired to the authorization registry, the session store and the token store.
func (c *Configuration) accessRevoker() auth.AccessRevoker {
	if c.sharedAccessRevoker != nil {
		return c.sharedAccessRevoker
	}
	c.sharedAccessRevoker = revoke.NewDefaultAccessRevoker(func(o *revoke.RevokerOption) {
		o.AuthRegistry = c.authorizationRegistry()
		o.SessionStore = c.sessionStore
		o.TokenStoreReader = c.tokenStore()
	})
	return c.sharedAccessRevoker
}
// approvalStore returns the configured ApprovalStore as-is; no default is created here.
func (c *Configuration) approvalStore() auth.ApprovalStore {
	return c.ApprovalStore
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package authserver
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/misc"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/openid"
utils_matcher "github.com/cisco-open/go-lanai/pkg/utils/matcher"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/rest"
"github.com/cisco-open/go-lanai/pkg/web/template"
)
// registerEndpoints registers the auth server's web endpoints with registrar:
//   - the error page
//   - the JWK set (all keys and by kid)
//   - check_token
//   - userinfo (GET and POST)
//   - the tenant hierarchy queries
//
// The OpenID provider configuration endpoint is added only when OpenIDSSOEnabled is set.
//
// Each userinfo method is registered twice with the same name. Requests matching
// acceptJwtMatcher get JwtUserInfo encoded by misc.JwtResponseEncoder. All other requests get
// PlainUserInfo.
func registerEndpoints(registrar *web.Registrar, config *Configuration) {
	jwkSetEp := misc.NewJwkSetEndpoint(config.jwkStore())
	checkTokenEp := misc.NewCheckTokenEndpoint(config.Issuer, config.tokenStore())
	userInfoEp := misc.NewUserInfoEndpoint(config.Issuer, config.UserAccountStore, config.jwtEncoder())
	tenantHierarchyEp := misc.NewTenantHierarchyEndpoint()

	// tenantHierarchyPath returns "<Endpoints.TenantHierarchy>/<sub>".
	tenantHierarchyPath := func(sub string) string {
		return fmt.Sprintf("%s/%s", config.Endpoints.TenantHierarchy, sub)
	}

	mappings := []interface{}{
		template.New().Get(config.Endpoints.Error).HandlerFunc(errorhandling.ErrorWithStatus).Build(),
		rest.New("jwks").Get(config.Endpoints.JwkSet).EndpointFunc(jwkSetEp.JwkSet).Build(),
		rest.New("jwks/kid").Get(config.Endpoints.JwkSet+"/:kid").EndpointFunc(jwkSetEp.JwkByKid).Build(),
		rest.New("check_token").Post(config.Endpoints.CheckToken).EndpointFunc(checkTokenEp.CheckToken).Build(),
		// userinfo GET: JWT variant, then plain variant
		rest.New("userinfo GET").Get(config.Endpoints.UserInfo).
			Condition(acceptJwtMatcher()).
			EncodeResponseFunc(misc.JwtResponseEncoder()).
			EndpointFunc(userInfoEp.JwtUserInfo).Build(),
		rest.New("userinfo GET").Get(config.Endpoints.UserInfo).
			Condition(notAcceptJwtMatcher()).
			EndpointFunc(userInfoEp.PlainUserInfo).Build(),
		// userinfo POST: JWT variant, then plain variant
		rest.New("userinfo POST").Post(config.Endpoints.UserInfo).
			Condition(acceptJwtMatcher()).
			EncodeResponseFunc(misc.JwtResponseEncoder()).
			EndpointFunc(userInfoEp.JwtUserInfo).Build(),
		rest.New("userinfo POST").Post(config.Endpoints.UserInfo).
			Condition(notAcceptJwtMatcher()).
			EndpointFunc(userInfoEp.PlainUserInfo).Build(),
		// tenant hierarchy: "parent" and "root" responses use misc.StringResponseEncoder
		rest.New("tenant hierarchy parent").Get(tenantHierarchyPath("parent")).
			EndpointFunc(tenantHierarchyEp.GetParent).EncodeResponseFunc(misc.StringResponseEncoder()).Build(),
		rest.New("tenant hierarchy children").Get(tenantHierarchyPath("children")).
			EndpointFunc(tenantHierarchyEp.GetChildren).Build(),
		rest.New("tenant hierarchy ancestors").Get(tenantHierarchyPath("ancestors")).
			EndpointFunc(tenantHierarchyEp.GetAncestors).Build(),
		rest.New("tenant hierarchy descendants").Get(tenantHierarchyPath("descendants")).
			EndpointFunc(tenantHierarchyEp.GetDescendants).Build(),
		rest.New("tenant hierarchy root").Get(tenantHierarchyPath("root")).
			EndpointFunc(tenantHierarchyEp.GetRoot).EncodeResponseFunc(misc.StringResponseEncoder()).Build(),
	}

	// OpenID provider configuration ("well-known") is only exposed with OpenID SSO enabled.
	if config.OpenIDSSOEnabled {
		wellKnown := prepareWellKnownEndpoint(config)
		mappings = append(mappings,
			rest.New("openid-config").Get(openid.WellKnownEndpointOPConfig).
				EndpointFunc(wellKnown.OpenIDConfig).Build(),
		)
	}
	registrar.MustRegister(mappings...)
}
// acceptJwtMatcher matches requests whose "Accept" header matches "application/jwt". The
// final `true` argument is passed through to matcher.RequestWithHeader.
func acceptJwtMatcher() web.RequestMatcher {
	const (
		headerName = "Accept"
		jwtMedia   = "application/jwt"
	)
	return matcher.RequestWithHeader(headerName, jwtMedia, true)
}

// notAcceptJwtMatcher is the negation of acceptJwtMatcher.
func notAcceptJwtMatcher() web.RequestMatcher {
	return utils_matcher.Not(acceptJwtMatcher())
}
// prepareWellKnownEndpoint builds the OpenID provider configuration endpoint. Besides the
// issuer and IdP manager, it advertises the authorize, token, userinfo, JWK set and
// end-session paths taken from config.Endpoints.
func prepareWellKnownEndpoint(config *Configuration) *misc.WellKnownEndpoint {
	metadata := make(map[string]interface{}, 5)
	metadata[openid.OPMetadataAuthEndpoint] = config.Endpoints.Authorize.Location.Path
	metadata[openid.OPMetadataTokenEndpoint] = config.Endpoints.Token
	metadata[openid.OPMetadataUserInfoEndpoint] = config.Endpoints.UserInfo
	metadata[openid.OPMetadataJwkSetURI] = config.Endpoints.JwkSet
	metadata[openid.OPMetadataEndSessionEndpoint] = config.Endpoints.Logout
	return misc.NewWellKnownEndpoint(config.Issuer, config.IdpManager, metadata)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package authserver
import (
"embed"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/timeoutsupport"
samlidp "github.com/cisco-open/go-lanai/pkg/security/saml/idp"
th_loader "github.com/cisco-open/go-lanai/pkg/tenancy/loader"
"go.uber.org/fx"
)
// defaultConfigFS embeds defaults-authserver.yml. Module registers it through
// appconfig.FxEmbeddedDefaults as this module's default configuration.
//
//go:embed defaults-authserver.yml
var defaultConfigFS embed.FS
// Module is the bootstrap module of the OAuth2 authorization server. It does the following:
//   - registers the embedded defaults
//   - binds AuthServerProperties
//   - provides the server's DI components, including the shared AccessRevoker
//   - invokes ConfigureAuthorizationServer
//
//goland:noinspection GoNameStartsWithPackageName
var Module = &bootstrap.Module{
	Name:       "oauth2 authserver",
	Precedence: security.MinSecurityPrecedence + 20,
	Options: []fx.Option{
		appconfig.FxEmbeddedDefaults(defaultConfigFS),
		fx.Provide(BindAuthServerProperties),
		fx.Provide(ProvideAuthServerDI),
		fx.Provide(provide),
		fx.Invoke(ConfigureAuthorizationServer),
	},
}
// Use enables the OAuth2 authorization server. It enables security, the tenant hierarchy
// loader and the SAML IdP, registers Module, and enables timeout support, in that order.
func Use() {
	security.Use()
	th_loader.Use()
	samlidp.Use() // enables SAML SSO/SLO
	bootstrap.Register(Module)
	timeoutsupport.Use()
	// Note: External SAML IDP support (samllogin package) is enabled as part of samlidp
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package authserver
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/pkg/errors"
)
const (
	// PropertiesPrefix is the configuration prefix BindAuthServerProperties binds from.
	PropertiesPrefix = "security.auth"
)
// AuthServerProperties are the authorization server properties bound under PropertiesPrefix.
//
//goland:noinspection GoNameStartsWithPackageName
type AuthServerProperties struct {
	Issuer IssuerProperties `json:"issuer"`
	// RedirectWhitelist is handed to the token-revoke logout success handler as its redirect whitelist.
	RedirectWhitelist []string            `json:"redirect-whitelist"`
	Endpoints         EndpointsProperties `json:"endpoints"`
}
// IssuerProperties describe the address this auth server uses to identify itself.
type IssuerProperties struct {
	// the protocol, either http or https
	Protocol string `json:"protocol"`
	// This server's host name.
	// Used to build the entity base URL. The entity URL identifies this auth server in SAML and OIDC exchanges.
	Domain string `json:"domain"`
	// Port this server listens on.
	Port int `json:"port"`
	// Context base path for this server.
	// Used to build the entity base URL. The entity URL identifies this auth server in a SAML exchange.
	ContextPath string `json:"context-path"`
	// NOTE(review): presumably controls whether Port appears in the built issuer URL — confirm where the URL is assembled.
	IncludePort bool `json:"include-port"`
}
// EndpointsProperties hold the endpoint paths of this auth server, such as "/v2/token".
// Defaults are set by NewAuthServerProperties.
type EndpointsProperties struct {
	Authorize       string `json:"authorize"`
	Token           string `json:"token"`
	Approval        string `json:"approval"`
	CheckToken      string `json:"check-token"`
	TenantHierarchy string `json:"tenant-hierarchy"`
	Error           string `json:"error"`
	Logout          string `json:"logout"`
	LoggedOut       string `json:"logged-out"`
	UserInfo        string `json:"user-info"`
	JwkSet          string `json:"jwk-set"`
	SamlMetadata    string `json:"saml-metadata"`
}
// NewAuthServerProperties creates an AuthServerProperties populated with default values.
// BindAuthServerProperties starts from these values and then binds configuration onto them.
func NewAuthServerProperties() *AuthServerProperties {
	return &AuthServerProperties{
		Issuer: IssuerProperties{
			Protocol: "http",
			// default host name; was previously misspelled "locahost"
			Domain:      "localhost",
			Port:        8080,
			ContextPath: "",
			IncludePort: true,
		},
		RedirectWhitelist: []string{},
		Endpoints: EndpointsProperties{
			Authorize:       "/v2/authorize",
			Token:           "/v2/token",
			Approval:        "/v2/approve",
			CheckToken:      "/v2/check_token",
			TenantHierarchy: "/v2/tenant_hierarchy",
			Error:           "/error",
			Logout:          "/v2/logout",
			UserInfo:        "/v2/userinfo",
			JwkSet:          "/v2/jwks",
			SamlMetadata:    "/metadata",
			LoggedOut:       "/",
		},
	}
}
// BindAuthServerProperties starts from the defaults of NewAuthServerProperties and binds the
// application configuration under PropertiesPrefix onto them. It panics if binding fails.
func BindAuthServerProperties(ctx *bootstrap.ApplicationContext) AuthServerProperties {
	props := NewAuthServerProperties()
	err := ctx.Config().Bind(props, PropertiesPrefix)
	if err == nil {
		return *props
	}
	panic(errors.Wrap(err, "failed to bind AuthServerProperties"))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package authserver
import (
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"go.uber.org/fx"
)
// provideDI holds the dependencies injected into provide.
type provideDI struct {
	fx.In
	Config *Configuration
}
// provideOut holds the components provide contributes to the fx container.
type provideOut struct {
	fx.Out
	AccessRevoker auth.AccessRevoker
}
// provide exposes the configuration's shared access revoker as an auth.AccessRevoker component.
func provide(di provideDI) provideOut {
	var out provideOut
	out.AccessRevoker = di.Config.accessRevoker()
	return out
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package authserver
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/csrf"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/logout"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/authorize"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/clientauth"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/openid"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/revoke"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/token"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/tokenauth"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
"github.com/cisco-open/go-lanai/pkg/security/request_cache"
"github.com/cisco-open/go-lanai/pkg/security/saml/idp"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
)
/***************************
additional abstractions
***************************/
// IdpSecurityConfigurer is implemented by IDPs to customize the "authorize" process.
// AuthorizeEndpointConfigurer calls Configure after it has set up the authorize and SAML SSO routes.
type IdpSecurityConfigurer interface {
	Configure(ws security.WebSecurity, config *Configuration)
}
// IdpLogoutSecurityConfigurer is an additional interface an IdpSecurityConfigurer may implement
// to customize the "logout" process.
//
// Note: ConfigureLogout is invoked only once per instance, and the given security.WebSecurity
// is shared between IDPs. Implementations therefore must not change Route or Condition on the
// given "ws".
type IdpLogoutSecurityConfigurer interface {
	ConfigureLogout(ws security.WebSecurity, config *Configuration)
}
/***************************
security configurers
***************************/
// ClientAuthEndpointsConfigurer implements security.Configurer and order.Ordered.
// It secures the endpoints that authenticate the calling OAuth2 client: token, check_token
// and the tenant hierarchy endpoints.
type ClientAuthEndpointsConfigurer struct {
	config *Configuration
}

// Order returns OrderClientAuthSecurityConfigurer.
func (c *ClientAuthEndpointsConfigurer) Order() int {
	return OrderClientAuthSecurityConfigurer
}

// Configure routes the token, check_token and "<tenant hierarchy>/*" paths through client
// authentication and installs the token endpoint backed by the shared token granter.
func (c *ClientAuthEndpointsConfigurer) Configure(ws security.WebSecurity) {
	conf := c.config
	clientAuth := clientauth.New().
		ClientStore(conf.ClientStore).
		ClientSecretEncoder(conf.clientSecretEncoder()).
		ErrorHandler(conf.errorHandler()).
		AllowForm(true) // AllowForm also implicitly enables Public Client
	tokenEndpoint := token.NewEndpoint().
		Path(conf.Endpoints.Token).
		AddGranter(conf.tokenGranter())

	ws.Route(matcher.RouteWithPattern(conf.Endpoints.Token)).
		Route(matcher.RouteWithPattern(conf.Endpoints.CheckToken)).
		Route(matcher.RouteWithPattern(fmt.Sprintf("%s/*", conf.Endpoints.TenantHierarchy))).
		With(clientAuth).
		// uncomment following if we want CheckToken to not allow public client
		//With(access.Configure(ws).
		//	Request(matcher.RequestWithPattern(conf.Endpoints.CheckToken)).
		//	AllowIf(access.HasPermissionsWithExpr("!public_client")),
		//).
		With(tokenEndpoint)
}
// TokenAuthEndpointsConfigurer implements security.Configurer and order.Ordered.
// It secures the endpoints that authenticate the caller with an access token (userinfo).
type TokenAuthEndpointsConfigurer struct {
	config *Configuration
}

// Order returns OrderTokenAuthSecurityConfigurer.
func (c *TokenAuthEndpointsConfigurer) Order() int {
	return OrderTokenAuthSecurityConfigurer
}

// Configure protects the UserInfo endpoint. It uses token authentication with EnablePostBody,
// requires every request to be authenticated, and adds the default error handling feature.
func (c *TokenAuthEndpointsConfigurer) Configure(ws security.WebSecurity) {
	tokenAuth := tokenauth.New().EnablePostBody()
	authenticatedOnly := access.New().Request(matcher.AnyRequest()).Authenticated()
	ws.Route(matcher.RouteWithPattern(c.config.Endpoints.UserInfo)).
		With(tokenAuth).
		With(authenticatedOnly).
		With(errorhandling.New())
}
// AuthorizeEndpointConfigurer implements security.Configurer and order.Ordered.
// It configures the "authorize" endpoint and the SAML SSO endpoint, then hands the same
// WebSecurity to the IDP-specific delegate.
type AuthorizeEndpointConfigurer struct {
	config   *Configuration
	delegate IdpSecurityConfigurer
}

// Order returns OrderAuthorizeSecurityConfigurer.
func (c *AuthorizeEndpointConfigurer) Order() int {
	return OrderAuthorizeSecurityConfigurer
}

// Configure installs the OAuth2 authorize endpoint and the SAML IdP SSO feature (with SLO on
// the logout path), then calls delegate.Configure.
func (c *AuthorizeEndpointConfigurer) Configure(ws security.WebSecurity) {
	conf := c.config
	authPath := conf.Endpoints.Authorize.Location.Path
	authorizeEndpoint := authorize.NewEndpoint().
		Path(authPath).
		Condition(conf.Endpoints.Authorize.Condition).
		ApprovalPath(conf.Endpoints.Approval).
		RequestProcessor(conf.authorizeRequestProcessor()).
		ErrorHandler(conf.errorHandler()).
		AuthorizeHanlder(conf.authorizeHandler()).
		ApprovalStore(conf.approvalStore())
	samlSso := samlidp.New().
		Issuer(conf.Issuer).
		SsoCondition(conf.Endpoints.SamlSso.Condition).
		SsoLocation(conf.Endpoints.SamlSso.Location).
		MetadataPath(conf.Endpoints.SamlMetadata).
		EnableSLO(conf.Endpoints.Logout).
		SigningMethod(conf.SamlIdpSigningMethod)

	ws.Route(matcher.RouteWithPattern(authPath)).
		With(authorizeEndpoint).
		Route(matcher.RouteWithPattern(conf.Endpoints.SamlSso.Location.Path)).
		With(samlSso)
	c.delegate.Configure(ws, conf)
}
// LogoutEndpointConfigurer implements security.Configurer and order.Ordered.
// It configures the "logout" endpoint: token revocation, OIDC logout handling and SAML SLO.
// Every delegate that implements IdpLogoutSecurityConfigurer can then customize it further.
type LogoutEndpointConfigurer struct {
	config    *Configuration
	delegates []IdpSecurityConfigurer
}

// Order returns OrderLogoutSecurityConfigurer.
func (c *LogoutEndpointConfigurer) Order() int {
	return OrderLogoutSecurityConfigurer
}

// Configure installs these features on ws:
//   - session
//   - authenticated-only access
//   - error handling
//   - CSRF (ignored for the logout path)
//   - request cache
//   - logout
//   - SAML logout
//
// It then invokes ConfigureLogout on each delegate that implements IdpLogoutSecurityConfigurer.
func (c *LogoutEndpointConfigurer) Configure(ws security.WebSecurity) {
	// Logout Handler
	// Note: we disable the default logout handler here because we don't want to unauthenticate
	// the user when PUT or DELETE is used
	logoutHandler := revoke.NewTokenRevokingLogoutHandler(func(opt *revoke.HanlderOption) {
		opt.Revoker = c.config.accessRevoker()
	})
	logoutSuccessHandler := revoke.NewTokenRevokeSuccessHandler(func(opt *revoke.SuccessOption) {
		opt.ClientStore = c.config.ClientStore
		opt.WhitelabelErrorPath = c.config.Endpoints.Error
		opt.RedirectWhitelist = utils.NewStringSet(c.config.properties.RedirectWhitelist...)
		opt.WhitelabelLoggedOutPath = c.config.Endpoints.LoggedOut
	})
	oidcLogoutHandler := openid.NewOidcLogoutHandler(func(opt *openid.HandlerOption) {
		// Use the lazy accessor rather than the sharedJwtDecoder field. The field is only
		// populated once jwtDecoder() has run, so reading it directly depends on some other
		// component having initialized it first and could otherwise yield a nil decoder.
		opt.Dec = c.config.jwtDecoder()
		opt.Issuer = c.config.Issuer
		opt.ClientStore = c.config.ClientStore
	})
	oidcLogoutSuccessHandler := openid.NewOidcSuccessHandler(func(opt *openid.SuccessOption) {
		opt.ClientStore = c.config.ClientStore
		opt.WhitelabelErrorPath = c.config.Endpoints.Error
	})
	oidcEntryPoint := openid.NewOidcEntryPoint(func(opt *openid.EpOption) {
		opt.WhitelabelErrorPath = c.config.Endpoints.Error
	})
	errHandler := redirect.NewRedirectWithURL(c.config.Endpoints.Error)
	ws.With(session.New().SettingService(c.config.SessionSettingService)).
		With(access.New().
			Request(matcher.AnyRequest()).Authenticated(),
		).
		With(errorhandling.New().
			AccessDeniedHandler(errHandler),
		).
		With(csrf.New().
			IgnoreCsrfProtectionMatcher(matcher.RequestWithPattern(c.config.Endpoints.Logout)),
		).
		With(request_cache.New()).
		With(logout.New().
			LogoutUrl(c.config.Endpoints.Logout).
			// By using this instead of AddLogoutHandler, the default logout handler is disabled.
			LogoutHandlers(logoutHandler, oidcLogoutHandler).
			AddSuccessHandler(logoutSuccessHandler).
			AddSuccessHandler(oidcLogoutSuccessHandler).
			AddEntryPoint(oidcEntryPoint),
		).
		With(samlidp.NewLogout().
			Issuer(c.config.Issuer).
			SsoCondition(c.config.Endpoints.SamlSso.Condition).
			SsoLocation(c.config.Endpoints.SamlSso.Location).
			MetadataPath(c.config.Endpoints.SamlMetadata).
			EnableSLO(c.config.Endpoints.Logout),
		)
	for _, configurer := range c.delegates {
		if lc, ok := configurer.(IdpLogoutSecurityConfigurer); ok {
			lc.ConfigureLogout(ws, c.config)
		}
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package compatibility
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/security"
)
// CompatibilityDiscoveryCustomizer implements discovery.ServiceRegistrationCustomizer.
// It publishes the security compatibility reference on the service registration, both as a
// "<tag>=<reference>" tag and as a metadata entry.
type CompatibilityDiscoveryCustomizer struct{}

// Customize adds the compatibility tag and metadata entry to reg. The context is unused.
func (c CompatibilityDiscoveryCustomizer) Customize(_ context.Context, reg discovery.ServiceRegistration) {
	reg.AddTags(fmt.Sprintf("%s=%s", security.CompatibilityReferenceTag, security.CompatibilityReference))
	reg.SetMeta(security.CompatibilityReferenceTag, security.CompatibilityReference)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package resserver
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/config/compatibility"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/common"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/tokenauth"
"go.uber.org/fx"
)
// ResourceServerConfigurer is provided by the application to customize the resource server
// Configuration. ProvideResServerDI calls it before the shared token store reader is created.
type ResourceServerConfigurer func(*Configuration)
// resServerConfigDI holds the dependencies injected into ProvideResServerDI.
type resServerConfigDI struct {
	fx.In
	AppContext         *bootstrap.ApplicationContext
	RedisClientFactory redis.ClientFactory
	CryptoProperties   jwt.CryptoProperties
	// TimeoutSupport is optional; if absent it stays nil and is passed as-is to the Redis context details store.
	TimeoutSupport oauth2.TimeoutApplier `optional:"true"`
	Configurer     ResourceServerConfigurer
}
// resServerOut holds the components ProvideResServerDI contributes to the fx container.
type resServerOut struct {
	fx.Out
	Config     *Configuration
	TokenStore oauth2.TokenStoreReader
	// CompatibilityCustomizer joins the "discovery" value group.
	CompatibilityCustomizer discovery.ServiceRegistrationCustomizer `group:"discovery"`
}
// ProvideResServerDI builds the resource server Configuration and applies the application's
// ResourceServerConfigurer. It then exposes the configuration, its shared token store reader
// and the discovery compatibility customizer. Remote endpoints default to host "authserver".
//
//goland:noinspection GoExportedFuncWithUnexportedType,HttpUrlsUsage
func ProvideResServerDI(di resServerConfigDI) resServerOut {
	config := &Configuration{
		appContext:         di.AppContext,
		cryptoProperties:   di.CryptoProperties,
		redisClientFactory: di.RedisClientFactory,
		timeoutSupport:     di.TimeoutSupport,
		RemoteEndpoints: RemoteEndpoints{
			Token:      "http://authserver/v2/token",
			CheckToken: "http://authserver/v2/check_token",
			UserInfo:   "http://authserver/v2/userinfo",
			JwkSet:     "http://authserver/v2/jwks",
		},
	}
	// customization happens before the token store reader is lazily created below
	di.Configurer(config)
	return resServerOut{
		Config:                  config,
		TokenStore:              config.SharedTokenStoreReader(),
		CompatibilityCustomizer: compatibility.CompatibilityDiscoveryCustomizer{},
	}
}
// resServerDI holds the dependencies injected into ConfigureResourceServer.
type resServerDI struct {
	fx.In
	Config            *Configuration
	SecurityRegistrar security.Registrar
}
// ConfigureResourceServer is the configuration entry point. It registers the token auth
// feature, backed by the configuration's token store reader. It panics if SecurityRegistrar
// does not implement security.FeatureRegistrar.
func ConfigureResourceServer(di resServerDI) {
	tokenAuthConfigurer := tokenauth.NewTokenAuthConfigurer(func(o *tokenauth.TokenAuthOption) {
		o.TokenStoreReader = di.Config.tokenStoreReader()
	})
	featureRegistrar := di.SecurityRegistrar.(security.FeatureRegistrar)
	featureRegistrar.RegisterFeature(tokenauth.FeatureId, tokenAuthConfigurer)
}
/****************************
configuration
****************************/
// RemoteEndpoints are the URLs of the authorization server's endpoints. ProvideResServerDI sets
// defaults on host "authserver", and a ResourceServerConfigurer may override them.
type RemoteEndpoints struct {
	Token      string
	CheckToken string
	UserInfo   string
	JwkSet     string
}
// Configuration of the OAuth2 resource server.
//
// Exported fields may be set by a ResourceServerConfigurer. A nil TokenStoreReader or JwkStore
// is replaced by a lazily created default. Unexported fields are injected dependencies and
// lazily initialized shared components.
type Configuration struct {
	// configurable items
	RemoteEndpoints  RemoteEndpoints
	TokenStoreReader oauth2.TokenStoreReader
	JwkStore         jwt.JwkStore
	// not directly configurable items
	appContext                *bootstrap.ApplicationContext
	redisClientFactory        redis.ClientFactory
	cryptoProperties          jwt.CryptoProperties
	sharedTokenAuthenticator  security.Authenticator
	sharedErrorHandler        *tokenauth.OAuth2ErrorHandler
	sharedContextDetailsStore security.ContextDetailsStore
	sharedJwtDecoder          jwt.JwtDecoder
	timeoutSupport            oauth2.TimeoutApplier
}
// SharedTokenStoreReader returns the resource server's token store reader. If none was
// configured, it creates the default JWT-based reader on first use.
func (c *Configuration) SharedTokenStoreReader() oauth2.TokenStoreReader {
	return c.tokenStoreReader()
}
// errorHandler returns the shared token-auth OAuth2 error handler, creating it on first use.
func (c *Configuration) errorHandler() *tokenauth.OAuth2ErrorHandler {
	if c.sharedErrorHandler != nil {
		return c.sharedErrorHandler
	}
	c.sharedErrorHandler = tokenauth.NewOAuth2ErrorHanlder()
	return c.sharedErrorHandler
}
// contextDetailsStore returns the shared Redis-backed context details store, creating it on
// first use from the app context, Redis client factory and timeout support.
func (c *Configuration) contextDetailsStore() security.ContextDetailsStore {
	if c.sharedContextDetailsStore != nil {
		return c.sharedContextDetailsStore
	}
	c.sharedContextDetailsStore = common.NewRedisContextDetailsStore(c.appContext, c.redisClientFactory, c.timeoutSupport)
	return c.sharedContextDetailsStore
}
// tokenStoreReader returns the configured TokenStoreReader. When none is set, it creates a JWT
// token store reader from the shared details store and JWT decoder, and writes it back to
// c.TokenStoreReader.
func (c *Configuration) tokenStoreReader() oauth2.TokenStoreReader {
	if c.TokenStoreReader != nil {
		return c.TokenStoreReader
	}
	c.TokenStoreReader = common.NewJwtTokenStoreReader(func(o *common.JTSROption) {
		o.DetailsStore = c.contextDetailsStore()
		o.Decoder = c.jwtDecoder()
	})
	return c.TokenStoreReader
}
// jwkStore returns the configured JwkStore. It defaults to jwt.NewFileJwkStore built from the
// crypto properties, and the default is written back to c.JwkStore.
func (c *Configuration) jwkStore() jwt.JwkStore {
	if c.JwkStore != nil {
		return c.JwkStore
	}
	c.JwkStore = jwt.NewFileJwkStore(c.cryptoProperties)
	return c.JwkStore
}
// jwtDecoder returns the shared verifying JWT decoder, creating it on first use. It verifies
// with the JWK store key named by cryptoProperties.Jwt.KeyName.
func (c *Configuration) jwtDecoder() jwt.JwtDecoder {
	if c.sharedJwtDecoder == nil {
		verifyWith := jwt.VerifyWithJwkStore(c.jwkStore(), c.cryptoProperties.Jwt.KeyName)
		c.sharedJwtDecoder = jwt.NewSignedJwtDecoder(verifyWith)
	}
	return c.sharedJwtDecoder
}
// tokenAuthenticator returns the shared token authenticator, creating it on first use. It
// reads tokens through tokenStoreReader().
func (c *Configuration) tokenAuthenticator() security.Authenticator {
	if c.sharedTokenAuthenticator != nil {
		return c.sharedTokenAuthenticator
	}
	c.sharedTokenAuthenticator = tokenauth.NewAuthenticator(func(o *tokenauth.AuthenticatorOption) {
		o.TokenStoreReader = c.tokenStoreReader()
	})
	return c.sharedTokenAuthenticator
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package resserver
import (
"embed"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/timeoutsupport"
"go.uber.org/fx"
)
// defaultConfigFS embeds defaults-resserver.yml. Module registers it through
// appconfig.FxEmbeddedDefaults as this module's default configuration.
//
//go:embed defaults-resserver.yml
var defaultConfigFS embed.FS
// Module is the bootstrap module of the OAuth2 resource server. It registers the embedded
// defaults, binds jwt.CryptoProperties, provides the resource server configuration and invokes
// ConfigureResourceServer.
//
//goland:noinspection GoNameStartsWithPackageName
var Module = &bootstrap.Module{
	// Must differ from the authorization server module ("oauth2 authserver"); the old value was
	// a copy-paste of that name, which misidentified this module wherever its name is reported.
	Name:       "oauth2 resserver",
	Precedence: security.MinSecurityPrecedence + 20,
	Options: []fx.Option{
		appconfig.FxEmbeddedDefaults(defaultConfigFS),
		fx.Provide(jwt.BindCryptoProperties),
		fx.Provide(ProvideResServerDI),
		fx.Invoke(ConfigureResourceServer),
	},
}
// Use enables the OAuth2 resource server. It enables security, registers Module and enables
// timeout support, in that order.
func Use() {
	security.Use()
	bootstrap.Register(Module)
	timeoutsupport.Use()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/config/authserver"
"github.com/cisco-open/go-lanai/pkg/security/idp"
"github.com/cisco-open/go-lanai/pkg/security/idp/extsamlidp"
"github.com/cisco-open/go-lanai/pkg/security/idp/passwdidp"
"github.com/cisco-open/go-lanai/pkg/security/idp/unknownIdp"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
"github.com/cisco-open/go-lanai/test/samltest"
"github.com/cisco-open/go-lanai/test/sectest"
"go.uber.org/fx"
)
// Identity provider fixtures used by the tests.
const (
	// IdpDomainPasswd is the domain of the mocked password IdP.
	IdpDomainPasswd = "passwd.lanai.com"
	// IdpDomainExtSAML is the domain of the mocked external SAML IdP.
	IdpDomainExtSAML = "saml.lanai.com"
	// External SAML IdP identity. NOTE(review): the entity ID and SSO/SLO URLs presumably
	// mirror testdata/ext-saml-metadata.xml — confirm.
	ExtSamlIdpName     = "ext-saml-idp"
	ExtSamlIdpEntityID = "http://external.saml.com/samlidp/metadata"
	ExtSamlIdpSSOUrl   = "http://external.saml.com/samlidp/authorize"
	ExtSamlIdpSLOUrl   = "http://external.saml.com/samlidp/logout"
)
// authDI holds the dependencies injected into NewAuthServerConfigurer. The Custom* fields are
// optional, and each is installed into the auth server configuration only when provided.
type authDI struct {
	fx.In
	MockingProperties   sectest.MockingProperties
	IdpManager          idp.IdentityProviderManager
	AccountStore        security.AccountStore
	PasswordEncoder     passwd.PasswordEncoder
	Properties          authserver.AuthServerProperties
	PasswdIDPProperties passwdidp.PwdAuthProperties
	SamlIDPProperties   extsamlidp.SamlAuthProperties
	CustomTokenGranter  auth.TokenGranter         `optional:"true"`
	CustomTokenEnhancer auth.TokenEnhancer        `optional:"true"`
	CustomAuthRegistry  auth.AuthorizationRegistry `optional:"true"`
}
func NewAuthServerConfigurer(di authDI) authserver.AuthorizationServerConfigurer {
return func(config *authserver.Configuration) {
// setup IDPs
config.AddIdp(passwdidp.NewPasswordIdpSecurityConfigurer(
passwdidp.WithProperties(&di.PasswdIDPProperties),
passwdidp.WithMFAListeners(),
))
config.AddIdp(extsamlidp.NewSamlIdpSecurityConfigurer(
extsamlidp.WithProperties(&di.SamlIDPProperties),
))
config.AddIdp(unknownIdp.NewNoIdpSecurityConfigurer())
config.IdpManager = di.IdpManager
config.ClientStore = sectest.NewMockedClientStore(di.MockingProperties.Clients.Values()...)
config.ClientSecretEncoder = di.PasswordEncoder
config.UserAccountStore = di.AccountStore
config.TenantStore = sectest.NewMockedTenantStore(di.MockingProperties.Tenants.Values()...)
config.ProviderStore = sectest.MockedProviderStore{}
config.UserPasswordEncoder = di.PasswordEncoder
config.SessionSettingService = StaticSessionSettingService(1)
if di.CustomTokenEnhancer != nil {
config.CustomTokenEnhancer = []auth.TokenEnhancer{di.CustomTokenEnhancer}
}
if di.CustomTokenGranter != nil {
config.CustomTokenGranter = []auth.TokenGranter{di.CustomTokenGranter}
}
if di.CustomAuthRegistry != nil {
config.CustomAuthRegistry = di.CustomAuthRegistry
}
}
}
// StaticSessionSettingService is a session setting service whose maximum-sessions setting is
// the integer value of the service itself, e.g. StaticSessionSettingService(1) reports 1.
type StaticSessionSettingService int

// GetMaximumSessions returns the configured constant; the context is ignored.
func (svc StaticSessionSettingService) GetMaximumSessions(_ context.Context) int {
	limit := int(svc)
	return limit
}
// NewMockedIDPManager builds a samltest.MockedIdpManager that knows the mocked external SAML IdP
// (domain IdpDomainExtSAML, metadata loaded from "testdata/ext-saml-metadata.xml") and delegates to a
// sectest mocked IdP manager configured with IdpDomainPasswd as its password IdP domain.
func NewMockedIDPManager() *samltest.MockedIdpManager {
	return samltest.NewMockedIdpManager(func(mockOpt *samltest.IdpManagerMockOption) {
		extSamlIdp := extsamlidp.NewIdentityProvider(func(details *extsamlidp.SamlIdpDetails) {
			details.EntityId = ExtSamlIdpEntityID
			details.Domain = IdpDomainExtSAML
			details.ExternalIdpName = ExtSamlIdpName
			details.ExternalIdName = "username"
			details.MetadataLocation = "testdata/ext-saml-metadata.xml"
		})
		passwdDelegate := sectest.NewMockedIDPManager(func(delegateOpt *sectest.IdpManagerMockOption) {
			delegateOpt.PasswdIDPDomain = IdpDomainPasswd
		})
		mockOpt.IDPList = []idp.IdentityProvider{extSamlIdp}
		mockOpt.Delegates = []idp.IdentityProviderManager{passwdDelegate}
	})
}
package testdata
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
)
// CustomAuthRegistry is a test double for auth.AuthorizationRegistry.
// Only RegisterAccessToken is implemented: it increments RegistrationCount and succeeds.
// Every other method panics, so any unexpected use surfaces immediately in tests.
type CustomAuthRegistry struct {
	// RegistrationCount is the number of RegisterAccessToken calls so far.
	RegistrationCount int
}

// NewCustomAuthRegistry returns a fresh *CustomAuthRegistry with a zero RegistrationCount.
func NewCustomAuthRegistry() auth.AuthorizationRegistry {
	return &CustomAuthRegistry{}
}

// RegisterRefreshToken is not implemented and panics.
func (c *CustomAuthRegistry) RegisterRefreshToken(ctx context.Context, token oauth2.RefreshToken, oauth oauth2.Authentication) error {
	panic("implement me")
}

// RegisterAccessToken counts the registration and always returns nil.
func (c *CustomAuthRegistry) RegisterAccessToken(ctx context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) error {
	c.RegistrationCount++
	return nil
}

// ReadStoredAuthorization is not implemented and panics.
func (c *CustomAuthRegistry) ReadStoredAuthorization(ctx context.Context, token oauth2.RefreshToken) (oauth2.Authentication, error) {
	panic("implement me")
}

// FindSessionId is not implemented and panics.
func (c *CustomAuthRegistry) FindSessionId(ctx context.Context, token oauth2.Token) (string, error) {
	panic("implement me")
}

// RevokeRefreshToken is not implemented and panics.
func (c *CustomAuthRegistry) RevokeRefreshToken(ctx context.Context, token oauth2.RefreshToken) error {
	panic("implement me")
}

// RevokeAccessToken is not implemented and panics.
func (c *CustomAuthRegistry) RevokeAccessToken(ctx context.Context, token oauth2.AccessToken) error {
	panic("implement me")
}

// RevokeAllAccessTokens is not implemented and panics.
func (c *CustomAuthRegistry) RevokeAllAccessTokens(ctx context.Context, token oauth2.RefreshToken) error {
	panic("implement me")
}

// RevokeUserAccess is not implemented and panics.
func (c *CustomAuthRegistry) RevokeUserAccess(ctx context.Context, username string, revokeRefreshToken bool) error {
	panic("implement me")
}

// RevokeClientAccess is not implemented and panics.
func (c *CustomAuthRegistry) RevokeClientAccess(ctx context.Context, clientId string, revokeRefreshToken bool) error {
	panic("implement me")
}

// RevokeSessionAccess is not implemented and panics.
// It uses a pointer receiver like every other method of CustomAuthRegistry.
func (c *CustomAuthRegistry) RevokeSessionAccess(ctx context.Context, sessionId string, revokeRefreshToken bool) error {
	panic("implement me")
}
package testdata
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
)
// CustomClaims wraps an existing oauth2.Claims and adds the MyClaim field (claim name "MyClaim").
// Every accessor delegates to the embedded oauth2.FieldClaimsMapper and passes the receiver along,
// presumably so the mapper can combine the tagged struct field with the wrapped Claims — confirm
// against FieldClaimsMapper's docs.
type CustomClaims struct {
	oauth2.FieldClaimsMapper
	oauth2.Claims
	MyClaim string `claim:"MyClaim"`
}

// Get returns the value of the named claim.
func (cc *CustomClaims) Get(claim string) interface{} {
	return cc.FieldClaimsMapper.Get(cc, claim)
}

// Has reports whether the named claim is present.
func (cc *CustomClaims) Has(claim string) bool {
	return cc.FieldClaimsMapper.Has(cc, claim)
}

// Set assigns value to the named claim.
func (cc *CustomClaims) Set(claim string, value interface{}) {
	cc.FieldClaimsMapper.Set(cc, claim, value)
}

// Values returns all claims as a map keyed by claim name.
func (cc *CustomClaims) Values() map[string]interface{} {
	return cc.FieldClaimsMapper.Values(cc)
}

// MarshalJSON implements json.Marshaler via the claims mapper.
func (cc *CustomClaims) MarshalJSON() ([]byte, error) {
	return cc.FieldClaimsMapper.DoMarshalJSON(cc)
}

// UnmarshalJSON implements json.Unmarshaler via the claims mapper.
func (cc *CustomClaims) UnmarshalJSON(bytes []byte) error {
	return cc.FieldClaimsMapper.DoUnmarshalJSON(cc, bytes)
}
// CustomTokenEnhancer is a test auth.TokenEnhancer that wraps a token's existing claims in
// CustomClaims and sets MyClaim to "my_claim_value".
type CustomTokenEnhancer struct{}

// NewCustomTokenEnhancer returns a new *CustomTokenEnhancer.
func NewCustomTokenEnhancer() auth.TokenEnhancer {
	return &CustomTokenEnhancer{}
}

// Enhance replaces the claims of a *oauth2.DefaultAccessToken with CustomClaims that wrap the
// existing claims and carry MyClaim = "my_claim_value", and returns the same token.
// It returns an internal error when the token is not a *oauth2.DefaultAccessToken, or when the
// token has no claims yet (this enhancer must run after BasicClaimsEnhancer).
func (c *CustomTokenEnhancer) Enhance(ctx context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) (oauth2.AccessToken, error) {
	t, ok := token.(*oauth2.DefaultAccessToken)
	if !ok {
		// report the incoming token's type; t is a nil *oauth2.DefaultAccessToken on this path
		return nil, oauth2.NewInternalError("unsupported token implementation %T", token)
	}
	if t.Claims() == nil {
		return nil, oauth2.NewInternalError("need to be placed after BasicClaimsEnhancer")
	}
	t.SetClaims(&CustomClaims{
		Claims:  t.Claims(),
		MyClaim: "my_claim_value",
	})
	return t, nil
}
package testdata
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
// CustomContextDetails is a minimal details object: every time is derived from time.Now()
// at the moment of the call, and roles and permissions are always empty sets.
type CustomContextDetails struct{}

// ExpiryTime returns the current time plus one minute.
func (CustomContextDetails) ExpiryTime() time.Time {
	now := time.Now()
	return now.Add(time.Minute)
}

// IssueTime returns the current time.
func (CustomContextDetails) IssueTime() time.Time {
	return time.Now()
}

// Roles returns a new, empty set.
func (CustomContextDetails) Roles() utils.StringSet {
	return utils.NewStringSet()
}

// Permissions returns a new, empty set.
func (CustomContextDetails) Permissions() utils.StringSet {
	return utils.NewStringSet()
}

// AuthenticationTime returns the current time.
func (CustomContextDetails) AuthenticationTime() time.Time {
	return time.Now()
}
// CustomTokenGranter is a test auth.TokenGranter handling the "custom_grant" grant type.
// It issues tokens for the fixed principal "custom_grant_principal" via the injected auth.AuthorizationService.
type CustomTokenGranter struct {
	authService auth.AuthorizationService
}

// Inject sets the auth.AuthorizationService that Grant uses to create access tokens.
// NOTE(review): presumably called by the framework during setup — confirm the injection mechanism.
func (g *CustomTokenGranter) Inject(authService auth.AuthorizationService) {
	g.authService = authService
}

// NewCustomTokenGranter returns a *CustomTokenGranter; Inject must be called before Grant is used.
func NewCustomTokenGranter() auth.TokenGranter {
	return &CustomTokenGranter{}
}

// Grant returns (nil, nil) for any grant type other than "custom_grant". For "custom_grant" it
// requires a fully authenticated client (otherwise an invalid-grant error), then creates an access token
// for an authenticated user "custom_grant_principal" with CustomContextDetails as details.
func (g *CustomTokenGranter) Grant(ctx context.Context, request *auth.TokenRequest) (oauth2.AccessToken, error) {
	if request.GrantType != "custom_grant" {
		return nil, nil
	}

	client, err := auth.RetrieveFullyAuthenticatedClient(ctx)
	if err != nil {
		return nil, oauth2.NewInvalidGrantError("requires client secret validated")
	}

	oauthReq := request.OAuth2Request(client)
	authentication := oauth2.NewAuthentication(func(authOpt *oauth2.AuthOption) {
		authOpt.Request = oauthReq
		authOpt.UserAuth = oauth2.NewUserAuthentication(func(userOpt *oauth2.UserAuthOption) {
			userOpt.Principal = "custom_grant_principal"
			userOpt.State = security.StateAuthenticated
		})
		authOpt.Details = &CustomContextDetails{}
	})
	return g.authService.CreateAccessToken(ctx, authentication)
}
package testdata
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
)
// MockedApprovalStore is an in-memory auth.ApprovalStore keeping approvals per username.
// It has no locking of its own.
type MockedApprovalStore struct {
	userApproval map[string][]*auth.Approval
}

// NewMockedApprovalStore returns an empty MockedApprovalStore.
func NewMockedApprovalStore() auth.ApprovalStore {
	store := &MockedApprovalStore{}
	store.userApproval = map[string][]*auth.Approval{}
	return store
}

// SaveApproval appends a to the approvals recorded for a.Username; it never fails.
func (s *MockedApprovalStore) SaveApproval(_ context.Context, a *auth.Approval) error {
	s.userApproval[a.Username] = append(s.userApproval[a.Username], a)
	return nil
}

// LoadApprovals applies opts to an empty auth.Approval used as the search criteria, then returns the
// approvals of criteria.Username whose ClientId equals criteria.ClientId (nil when none match).
func (s *MockedApprovalStore) LoadApprovals(_ context.Context, opts ...auth.ApprovalLoadOptions) ([]*auth.Approval, error) {
	criteria := &auth.Approval{}
	for _, apply := range opts {
		apply(criteria)
	}

	var matched []*auth.Approval
	for _, approval := range s.userApproval[criteria.Username] {
		if approval.ClientId != criteria.ClientId {
			continue
		}
		matched = append(matched, approval)
	}
	return matched, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/config/resserver"
"go.uber.org/fx"
)
// secDI declares the fx dependencies of NewResServerConfigurer. The registrar is requested but not used;
// NOTE(review): presumably this only forces the security registrar to be constructed — confirm.
type secDI struct {
	fx.In
	SecRegistrar security.Registrar
}

// NewResServerConfigurer returns a resserver.ResourceServerConfigurer that leaves the
// resource server configuration untouched.
func NewResServerConfigurer(_ secDI) resserver.ResourceServerConfigurer {
	noop := func(_ *resserver.Configuration) {
		// intentionally empty
	}
	return noop
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
)
// ErrorPageSecurityConfigurer secures the "/error" route: it enables the session feature and
// permits any request to that route.
type ErrorPageSecurityConfigurer struct{}

// Configure implements security.Configurer.
func (c *ErrorPageSecurityConfigurer) Configure(ws security.WebSecurity) {
	errorRoute := matcher.RouteWithPattern("/error")
	ws.Route(errorRoute).
		With(session.New()).
		With(access.New().Request(matcher.AnyRequest()).PermitAll())
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/mapping"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
"go.uber.org/fx"
)
/************************************
Interfaces for setting security
*************************************/
// Configurer customizes security. Configurers are registered to a Registrar, and each one
// receives its own, newly created WebSecurity to configure.
type Configurer interface {
	Configure(WebSecurity)
}

// ConfigurerFunc adapts an ordinary function to the Configurer interface.
type ConfigurerFunc func(ws WebSecurity)

// compile-time check that ConfigurerFunc satisfies Configurer
var _ Configurer = ConfigurerFunc(nil)

// Configure implements Configurer by calling the function itself.
func (fn ConfigurerFunc) Configure(ws WebSecurity) {
	fn(ws)
}

/************************************
	Interfaces for other modules
*************************************/

// Registrar is the entry point to configure security.
type Registrar interface {
	// Register adds Configurers; it is the entry point for all security configuration.
	// Microservices or other library packages typically call it in fx Provide or Invoke functions.
	// Note: calling it inside fx.Lifecycle has no effect.
	Register(configurers ...Configurer)
}

// Initializer is the entry point to bootstrap security.
type Initializer interface {
	// Initialize is the entry point for all security configuration.
	// Microservices or other library packages typically call it in fx Provide or Invoke functions.
	// Note: calling it inside fx.Lifecycle has no effect.
	Initialize(ctx context.Context, lc fx.Lifecycle, registrar *web.Registrar) error
}
/****************************************
Type definitions for
specifying web security specs
*****************************************/
// MiddlewareTemplate is a partially configured middleware.MappingBuilder.
// It holds the middleware's gin.HandlerFunc and order.
// If its route matcher and condition are not set, WebSecurity makes them match WebSecurity's own values.
type MiddlewareTemplate *middleware.MappingBuilder

// SimpleMappingTemplate is a partially configured mapping.MappingBuilder.
// It holds the simple mapping's path, gin.HandlerFunc and order.
// If its condition is not set, WebSecurity makes it match WebSecurity's own values.
type SimpleMappingTemplate *mapping.MappingBuilder

// FeatureIdentifier is unique for each feature.
// The security initializer uses this value to locate the corresponding FeatureConfigurer
// and to sort the configuration order.
type FeatureIdentifier interface {
	fmt.Stringer
	fmt.GoStringer
}

// Feature holds the security settings of a specific feature.
// Any Feature should have a corresponding FeatureConfigurer.
type Feature interface {
	Identifier() FeatureIdentifier
}

// WebSecurity holds information on security configuration.
type WebSecurity interface {
	// Context returns the context associated with the WebSecurity.
	// It typically holds bootstrap.ApplicationContext or a context derived from it.
	// It should not return nil.
	Context() context.Context
	// Route configures the path and method pattern which this WebSecurity applies to.
	// Calling this method multiple times concatenates all given matchers with the OR operator.
	Route(web.RouteMatcher) WebSecurity
	// Condition sets additional conditions of incoming requests which this WebSecurity applies to.
	// Calling this method multiple times concatenates all given matchers with the OR operator.
	Condition(mwcm web.RequestMatcher) WebSecurity
	// AndCondition sets additional conditions of incoming requests which this WebSecurity applies to.
	// Calling this method multiple times concatenates all given matchers with the AND operator.
	AndCondition(mwcm web.RequestMatcher) WebSecurity
	// Add is a DSL-style setter to add:
	// - MiddlewareTemplate
	// - web.MiddlewareMapping
	// - web.MvcMapping
	// - web.StaticMapping
	// - web.SimpleMapping
	// When a MiddlewareTemplate is given, WebSecurity's Route and Condition are applied to it.
	// This method panics if any other type is given.
	Add(...interface{}) WebSecurity
	// Remove is a DSL-style setter to remove:
	// - MiddlewareTemplate
	// - web.MiddlewareMapping
	// - web.MvcMapping
	// - web.StaticMapping
	// - web.SimpleMapping
	Remove(...interface{}) WebSecurity
	// With is a DSL-style setter to enable features.
	With(f Feature) WebSecurity
	// Shared returns the shared value stored under the given key.
	Shared(key string) interface{}
	// AddShared adds a shared value. It returns an error when the key already exists.
	AddShared(key string, value interface{}) error
	// Authenticator returns the Authenticator.
	Authenticator() Authenticator
	// Features returns the currently configured Feature list.
	Features() []Feature
}
/****************************************
Convenient Types
*****************************************/
// simpleFeatureId is an unordered FeatureIdentifier backed by a plain string.
type simpleFeatureId string

// String implements FeatureIdentifier interface
func (s simpleFeatureId) String() string {
	return string(s)
}

// GoString implements FeatureIdentifier interface; it is identical to String.
func (s simpleFeatureId) GoString() string {
	return s.String()
}

// SimpleFeatureId creates an unordered FeatureIdentifier.
func SimpleFeatureId(id string) FeatureIdentifier {
	return simpleFeatureId(id)
}

// featureId is an ordered FeatureIdentifier: it exposes its order via Order.
type featureId struct {
	id    string
	order int
}

// Order implements order.Ordered interface
func (f featureId) Order() int {
	return f.order
}

// String implements FeatureIdentifier interface
func (f featureId) String() string {
	return f.id
}

// GoString implements FeatureIdentifier interface; it is identical to String.
func (f featureId) GoString() string {
	return f.String()
}

// FeatureId creates an ordered FeatureIdentifier.
func FeatureId(id string, order int) FeatureIdentifier {
	return featureId{id: id, order: order}
}

// priorityFeatureId is a priority-ordered FeatureIdentifier: it exposes its order via PriorityOrder.
type priorityFeatureId struct {
	id    string
	order int
}

// PriorityOrder implements order.PriorityOrdered interface
func (p priorityFeatureId) PriorityOrder() int {
	return p.order
}

// String implements FeatureIdentifier interface
func (p priorityFeatureId) String() string {
	return p.id
}

// GoString implements FeatureIdentifier interface; it is identical to String.
func (p priorityFeatureId) GoString() string {
	return p.String()
}

// PriorityFeatureId creates a priority-ordered FeatureIdentifier.
func PriorityFeatureId(id string, order int) FeatureIdentifier {
	return priorityFeatureId{id: id, order: order}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"context"
"encoding/gob"
"errors"
"fmt"
securityinternal "github.com/cisco-open/go-lanai/internal/security"
"github.com/cisco-open/go-lanai/pkg/tenancy"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/gin-gonic/gin"
"time"
)
// HighestMiddlewareOrder and LowestMiddlewareOrder bound a range of middleware order values
// spanning -0x3ffff (-262143) to -0x30000 (-196608).
const (
	HighestMiddlewareOrder = int(-0x3ffff)
	LowestMiddlewareOrder  = HighestMiddlewareOrder + 0xffff
)

// AuthenticationState describes how far an Authentication has progressed.
// The states are ordered: StateAnonymous < StatePrincipalKnown < StateAuthenticated.
type AuthenticationState int

const (
	StateAnonymous AuthenticationState = iota
	StatePrincipalKnown
	StateAuthenticated
)

// Permissions is a set of granted permission names; only the keys are significant.
type Permissions map[string]interface{}

// Has reports whether permission is a key of p, regardless of its value.
func (p Permissions) Has(permission string) bool {
	if _, present := p[permission]; present {
		return true
	}
	return false
}

// Authentication is the security information attached to a request or context.
type Authentication interface {
	Principal() interface{}
	Permissions() Permissions
	State() AuthenticationState
	Details() interface{}
}

// EmptyAuthentication represents an unauthenticated user: no principal, no details,
// no permissions and StateAnonymous.
// Note: an anonymous user is considered authenticated, so it is not an EmptyAuthentication.
type EmptyAuthentication string

// Principal always returns nil.
func (EmptyAuthentication) Principal() interface{} {
	return nil
}

// Details always returns nil.
func (EmptyAuthentication) Details() interface{} {
	return nil
}

// State always returns StateAnonymous.
func (EmptyAuthentication) State() AuthenticationState {
	return StateAnonymous
}

// Permissions returns a new, empty (non-nil) Permissions.
func (EmptyAuthentication) Permissions() Permissions {
	return Permissions{}
}

// GlobalSettingReader reads global settings.
type GlobalSettingReader interface {
	// Read reads the setting of the given key into dest.
	// Implementations should support dest of type *[]byte, *string, *bool and *int.
	Read(ctx context.Context, key string, dest interface{}) error
}
// GobRegister registers the security package's concrete types with encoding/gob, in a fixed order,
// so that values stored behind interfaces can be gob-encoded and decoded.
func GobRegister() {
	prototypes := []interface{}{
		EmptyAuthentication(""),
		(*AnonymousAuthentication)(nil),
		(*CodedError)(nil),
		errors.New(""),
		(*DefaultAccount)(nil),
		(*AcctDetails)(nil),
		(*AcctLockingRule)(nil),
		(*AcctPasswordPolicy)(nil),
		(*AccountMetadata)(nil),
	}
	for _, prototype := range prototypes {
		gob.Register(prototype)
	}
}
/**********************************
Common Functions
**********************************/
// securityCtxKey is the context key under which the current Authentication is stored.
type securityCtxKey struct{}

// Get returns the Authentication stored in ctx, or EmptyAuthentication("not authenticated")
// when none (or a nil one) is stored.
func Get(ctx context.Context) Authentication {
	if stored, ok := ctx.Value(securityCtxKey{}).(Authentication); ok {
		return stored
	}
	return EmptyAuthentication("not authenticated")
}

// MustSet is the panicking version of Set.
func MustSet(ctx context.Context, auth Authentication) {
	err := Set(ctx, auth)
	if err != nil {
		panic(err)
	}
}

// Set stores auth as the security context. It returns an internal error if ctx is not backed by
// utils.MutableContext. When ctx also exposes a gin context, gin.AuthUserKey is set to auth's
// principal (or nil when auth is nil).
func Set(ctx context.Context, auth Authentication) error {
	mutable := utils.FindMutableContext(ctx)
	if mutable == nil {
		return NewInternalError(fmt.Sprintf(`unable to set security into context: given context [%T] is not mutable`, ctx))
	}
	mutable.Set(securityCtxKey{}, auth)

	gc := web.GinContext(ctx)
	if gc == nil {
		return nil
	}
	var principal interface{}
	if auth != nil {
		principal = auth.Principal()
	}
	gc.Set(gin.AuthUserKey, principal)
	return nil
}

// MustClear is the panicking version of Clear.
func MustClear(ctx context.Context) {
	err := Clear(ctx)
	if err != nil {
		panic(err)
	}
}

// Clear sets the security context to EmptyAuthentication("not authenticated").
// It returns the same error as Set when that is not possible.
func Clear(ctx context.Context) error {
	return Set(ctx, EmptyAuthentication("not authenticated"))
}
// HasPermissions reports whether auth holds every one of the given permissions.
// It returns true when no permission is given.
func HasPermissions(auth Authentication, permissions ...string) bool {
	for i := range permissions {
		if granted := auth.Permissions(); !granted.Has(permissions[i]) {
			return false
		}
	}
	return true
}
// IsTenantValid reports whether tenantId is part of the tenant hierarchy: either it has a parent,
// or it is the root tenant. In most cases HasAccessToTenant should be used instead, since it checks
// both the tenant's validity and whether the current user has access to it.
func IsTenantValid(ctx context.Context, tenantId string) bool {
	// a known parent means the tenant is in the hierarchy
	if parentId, e := tenancy.GetParent(ctx, tenantId); e == nil && parentId != "" {
		return true
	}
	// the root tenant has no parent, so it has to be compared explicitly
	rootId, e := tenancy.GetRoot(ctx)
	return e == nil && rootId != "" && rootId == tenantId
}

// HasAccessToTenant reports whether HasErrorAccessingTenant returns nil for tenantId.
func HasAccessToTenant(ctx context.Context, tenantId string) bool {
	return HasErrorAccessingTenant(ctx, tenantId) == nil
}

// HasErrorAccessingTenant checks whether the current security context may access tenantId:
//   - ErrorInvalidTenantId is returned when IsTenantValid reports the tenant as invalid;
//   - nil is returned when the authentication details implement securityinternal.TenantAccessDetails and
//     their effective assigned tenant IDs either contain SpecialTenantIdWildcard, or make
//     tenancy.AnyHasDescendant(ctx, assignedIds, tenantId) return true;
//   - ErrorTenantAccessDenied is returned otherwise.
func HasErrorAccessingTenant(ctx context.Context, tenantId string) error {
	if !IsTenantValid(ctx, tenantId) {
		return ErrorInvalidTenantId
	}
	details, ok := Get(ctx).Details().(securityinternal.TenantAccessDetails)
	if !ok {
		return ErrorTenantAccessDenied
	}
	if details.EffectiveAssignedTenantIds().Has(SpecialTenantIdWildcard) {
		return nil
	}
	if tenancy.AnyHasDescendant(ctx, details.EffectiveAssignedTenantIds(), tenantId) {
		return nil
	}
	return ErrorTenantAccessDenied
}
// IsFullyAuthenticated reports whether auth is at least in StateAuthenticated.
func IsFullyAuthenticated(auth Authentication) bool {
	return auth.State() >= StateAuthenticated
}

// IsBeingAuthenticated reports whether going from "from" to "to" is a login: "from" is nil or below
// StateAuthenticated, and "to" is non-nil and above StatePrincipalKnown.
func IsBeingAuthenticated(from, to Authentication) bool {
	if from != nil && from.State() >= StateAuthenticated {
		return false
	}
	return to != nil && to.State() > StatePrincipalKnown
}

// IsBeingUnAuthenticated reports whether going from "from" to "to" is a logout: "from" is non-nil and
// above StateAnonymous, and "to" is nil or at most StateAnonymous.
func IsBeingUnAuthenticated(from, to Authentication) bool {
	if from == nil || from.State() <= StateAnonymous {
		return false
	}
	return to == nil || to.State() <= StateAnonymous
}
// DetermineAuthenticationTime extracts the authentication time from userAuth's details.
// The details must be a map[string]interface{} whose DetailsKeyAuthTime entry is either a time.Time
// or a string, which is parsed with utils.ParseTime using time.RFC3339. In every other case
// (nil userAuth, other details type, missing key, other value type) the zero time.Time is returned.
func DetermineAuthenticationTime(_ context.Context, userAuth Authentication) (authTime time.Time) {
	if userAuth == nil {
		return
	}
	if details, ok := userAuth.Details().(map[string]interface{}); ok {
		switch raw := details[DetailsKeyAuthTime].(type) {
		case time.Time:
			authTime = raw
		case string:
			authTime = utils.ParseTime(time.RFC3339, raw)
		}
	}
	return
}
// GetUsername derives a username from userAuth's principal. Principals are checked in this order:
// an Account (its Username()), a string (used as-is), then a fmt.Stringer (its String()).
// It returns an error when userAuth is nil or the principal has any other type.
func GetUsername(userAuth Authentication) (string, error) {
	if userAuth == nil {
		return "", fmt.Errorf("unsupported authentication is nil")
	}
	switch p := userAuth.Principal().(type) {
	case Account:
		return p.Username(), nil
	case string:
		return p, nil
	case fmt.Stringer:
		return p.String(), nil
	default:
		return "", fmt.Errorf("unsupported principal type %T", p)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package csrf
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
)
// FeatureId identifies the CSRF feature.
//
// The CSRF feature uses the synchronizer token pattern to prevent cross-site request forgery:
// https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#synchronizer-token-pattern
var FeatureId = security.FeatureId("csrf", security.FeatureOrderCsrf)

// Feature holds the CSRF settings: extra matchers for requests that require or are exempt from
// protection, and an optional handler for CSRF-related access-denied errors.
type Feature struct {
	requireCsrfProtectionMatchers []web.RequestMatcher
	ignoreCsrfProtectionMatchers  []web.RequestMatcher
	csrfDeniedHandler             security.AccessDeniedHandler
}

// Configure enables a new CSRF Feature on ws and returns the enabled instance.
// It panics when ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *Feature {
	modifier, ok := ws.(security.FeatureModifier)
	if !ok {
		panic(fmt.Errorf("unable to configure CSRF: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
	}
	return modifier.Enable(New()).(*Feature)
}

// New returns a CSRF Feature with no extra matchers and no denied handler.
func New() *Feature {
	return &Feature{}
}

// AddCsrfProtectionMatcher adds m to the matchers of requests requiring CSRF protection.
func (f *Feature) AddCsrfProtectionMatcher(m web.RequestMatcher) *Feature {
	f.requireCsrfProtectionMatchers = append(f.requireCsrfProtectionMatchers, m)
	return f
}

// IgnoreCsrfProtectionMatcher adds m to the matchers of requests exempt from CSRF protection.
func (f *Feature) IgnoreCsrfProtectionMatcher(m web.RequestMatcher) *Feature {
	f.ignoreCsrfProtectionMatchers = append(f.ignoreCsrfProtectionMatchers, m)
	return f
}

// CsrfDeniedHandler sets the handler invoked for CSRF-related access-denied errors.
func (f *Feature) CsrfDeniedHandler(csrfDeniedHandler security.AccessDeniedHandler) *Feature {
	f.csrfDeniedHandler = csrfDeniedHandler
	return f
}

// Identifier implements security.Feature.
func (f *Feature) Identifier() security.FeatureIdentifier {
	return FeatureId
}

// Configurer applies a CSRF Feature to a WebSecurity.
type Configurer struct{}

// newCsrfConfigurer returns a new *Configurer.
func newCsrfConfigurer() *Configurer {
	return &Configurer{}
}
// Apply installs the CSRF feature on ws. It:
//   - registers a CsrfDeniedHandler wrapping the feature's denied handler, if one was set;
//   - registers a ChangeCsrfHandler as an authentication success handler;
//   - adds the "csrfMiddleware" middleware, whose protection and ignore matchers are the defaults
//     OR'ed with every matcher added to the feature.
//
// Both handlers and the middleware share one session-backed token store.
func (sc *Configurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	f := feature.(*Feature)

	if f.csrfDeniedHandler != nil {
		deniedHandlers := ws.Shared(security.WSSharedKeyCompositeAccessDeniedHandler).(*security.CompositeAccessDeniedHandler)
		deniedHandlers.Add(&CsrfDeniedHandler{delegate: f.csrfDeniedHandler})
	}

	store := newSessionBackedStore()

	successHandlers := ws.Shared(security.WSSharedKeyCompositeAuthSuccessHandler).(*security.CompositeAuthenticationSuccessHandler)
	successHandlers.Add(&ChangeCsrfHandler{csrfTokenStore: store})

	required := DefaultProtectionMatcher
	for _, m := range f.requireCsrfProtectionMatchers {
		required = required.Or(m)
	}
	ignored := DefaultIgnoreMatcher
	for _, m := range f.ignoreCsrfProtectionMatchers {
		ignored = ignored.Or(m)
	}

	mgr := newManager(store, required, ignored)
	ws.Add(middleware.NewBuilder("csrfMiddleware").
		Order(security.MWOrderCsrfHandling).
		Use(mgr.CsrfHandlerFunc()))
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package csrf
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"net/http"
)
// ChangeCsrfHandler is an authentication success handler that rotates the CSRF token
// when a request becomes authenticated.
type ChangeCsrfHandler struct {
	csrfTokenStore TokenStore
}

// HandleAuthenticationSuccess does nothing unless security.IsBeingAuthenticated(from, to) is true.
// Otherwise it loads the current token. If one exists, it generates a new token with the same
// parameter and header names and saves it. Finally it stores the resulting (possibly nil) token in the
// context via MustSet. Load or save failures panic with a security internal error.
func (h *ChangeCsrfHandler) HandleAuthenticationSuccess(c context.Context, _ *http.Request, _ http.ResponseWriter, from, to security.Authentication) {
	if !security.IsBeingAuthenticated(from, to) {
		return
	}
	// TODO: review error handling of this block
	t, err := h.csrfTokenStore.LoadToken(c)
	if err != nil {
		panic(security.NewInternalError(err.Error()))
	}
	if t != nil {
		t = h.csrfTokenStore.Generate(c, t.ParameterName, t.HeaderName)
		if e := h.csrfTokenStore.SaveToken(c, t); e != nil {
			// report the SaveToken failure; err is nil at this point, so err.Error() would nil-panic
			panic(security.NewInternalError(e.Error()))
		}
	}
	MustSet(c, t)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package csrf
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/gin-gonic/gin"
"net/http"
)
// csrfCtxKey is the context key under which the CSRF *Token is stored.
type csrfCtxKey struct{}

var (
	// DefaultProtectionMatcher selects the requests protected by default:
	// every request whose method is not GET, HEAD, TRACE or OPTIONS.
	DefaultProtectionMatcher = matcher.NotRequest(matcher.RequestWithMethods("GET", "HEAD", "TRACE", "OPTIONS"))
	// DefaultIgnoreMatcher is the default exemption matcher (matcher.NoneRequest()).
	DefaultIgnoreMatcher = matcher.NoneRequest()
)

// Get returns the Token stored in the given context, or nil when there is none.
func Get(c context.Context) *Token {
	if token, ok := c.Value(csrfCtxKey{}).(*Token); ok {
		return token
	}
	return nil
}

// MustSet is the panicking version of Set.
func MustSet(c context.Context, t *Token) {
	err := Set(c, t)
	if err != nil {
		panic(err)
	}
}

// Set stores t in the given context. It returns an internal error when the context is not
// backed by utils.MutableContext.
func Set(c context.Context, t *Token) error {
	mutable := utils.FindMutableContext(c)
	if mutable == nil {
		return security.NewInternalError(fmt.Sprintf(`unable to set CSRF token into context: given context [%T] is not mutable`, c))
	}
	mutable.Set(csrfCtxKey{}, t)
	return nil
}
// manager validates CSRF tokens for matching requests, using tokenStore as the source of truth.
type manager struct {
	tokenStore        TokenStore
	requireProtection web.RequestMatcher
	ignoreProtection  web.RequestMatcher
	parameterName     string
	headerName        string
}

// newManager creates a manager that uses security.CsrfParamName and security.CsrfHeaderName.
// A nil matcher falls back to DefaultProtectionMatcher / DefaultIgnoreMatcher respectively.
func newManager(tokenStore TokenStore, csrfProtectionMatcher web.RequestMatcher, ignoreProtectionMatcher web.RequestMatcher) *manager {
	m := &manager{
		tokenStore:        tokenStore,
		parameterName:     security.CsrfParamName,
		headerName:        security.CsrfHeaderName,
		requireProtection: DefaultProtectionMatcher,
		ignoreProtection:  DefaultIgnoreMatcher,
	}
	if csrfProtectionMatcher != nil {
		m.requireProtection = csrfProtectionMatcher
	}
	if ignoreProtectionMatcher != nil {
		m.ignoreProtection = ignoreProtectionMatcher
	}
	return m
}
// CsrfHandlerFunc returns the gin middleware that enforces CSRF protection.
//
// For each request it:
//  1. loads the expected token from the token store, generating and saving a new one if none exists;
//  2. stores the expected token in the request context (see Set) so templates can render it;
//  3. when the request is matched by requireProtection and not by ignoreProtection, compares the token
//     sent in the m.headerName header (or, if that is empty, the m.parameterName form field) with
//     the expected one, recording a missing-token or invalid-token error on mismatch.
//
// Every failure records an error on the gin context and aborts the handler chain. gin's Context.Abort
// only prevents pending handlers from running; it does not stop the current one, so each abort is
// followed by an explicit return to avoid acting on invalid state and stacking further errors.
func (m *manager) CsrfHandlerFunc() gin.HandlerFunc {
	return func(c *gin.Context) {
		expectedToken, err := m.tokenStore.LoadToken(c)
		// something is wrong with reading the csrf token from storage (e.g. it can't be deserialized or the
		// storage is not accessible); we can't recover from that, so abort
		if err != nil {
			_ = c.Error(security.NewInternalError(err.Error()))
			c.Abort()
			return
		}

		if expectedToken == nil {
			expectedToken = m.tokenStore.Generate(c, m.parameterName, m.headerName)
			if err = m.tokenStore.SaveToken(c, expectedToken); err != nil {
				_ = c.Error(security.NewInternalError(err.Error()))
				c.Abort()
				return
			}
		}

		// expose the token so that templates know what to render; we don't depend on the value being
		// stored in the session, which decouples rendering from the store implementation
		if err = Set(c, expectedToken); err != nil {
			_ = c.Error(security.NewInternalError(err.Error()))
			c.Abort()
			return
		}

		matches, err := m.requireProtection.MatchesWithContext(c, c.Request)
		if err != nil {
			_ = c.Error(security.NewInternalError(err.Error()))
			c.Abort()
			return
		}
		ignores, err := m.ignoreProtection.MatchesWithContext(c, c.Request)
		if err != nil {
			_ = c.Error(security.NewInternalError(err.Error()))
			c.Abort()
			return
		}
		if !matches || ignores {
			return
		}

		actualToken := c.GetHeader(m.headerName)
		if actualToken == "" {
			actualToken, _ = c.GetPostForm(m.parameterName)
		}
		// both cases deny access; only the error type and message differ
		switch {
		case actualToken == "":
			_ = c.Error(security.NewMissingCsrfTokenError("request is missing csrf token"))
			c.Abort()
		case actualToken != expectedToken.Value:
			_ = c.Error(security.NewInvalidCsrfTokenError("request has invalid csrf token"))
			c.Abort()
		}
	}
}
// CsrfDeniedHandler is a security.AccessDeniedHandler that reacts only to CSRF errors
// (security.ErrorSubTypeCsrf) and forwards them to its delegate.
type CsrfDeniedHandler struct {
	delegate security.AccessDeniedHandler
}

// Order implements order.Ordered. It always returns 0.
func (h *CsrfDeniedHandler) Order() int {
	return 0
}

// HandleAccessDenied implements security.AccessDeniedHandler. Errors that are not CSRF errors are ignored.
func (h *CsrfDeniedHandler) HandleAccessDenied(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	if !errors.Is(err, security.ErrorSubTypeCsrf) {
		return
	}
	h.delegate.HandleAccessDenied(c, r, rw, err)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package csrf
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/web/template"
"go.uber.org/fx"
)
// Module is the bootstrap module of the CSRF feature. Its only fx option invokes register.
var Module = &bootstrap.Module{
Name: "csrf",
Precedence: security.MinSecurityPrecedence + 20, //after session
Options: []fx.Option{
fx.Invoke(register),
},
}
// init registers Module with bootstrap and exposes the request's CSRF token (as returned by Get) to
// templates under template.ModelKeyCsrf.
func init() {
bootstrap.Register(Module)
template.RegisterGlobalModelValuer(template.ModelKeyCsrf, template.ContextModelValuer(Get))
}
// initDI holds the dependencies injected into register. SecRegistrar is optional; register does nothing
// when it is nil.
type initDI struct {
fx.In
SecRegistrar security.Registrar `optional:"true"`
}
// register registers the CSRF feature and its configurer with the security registrar, if one was injected.
// It panics if the registrar does not implement security.FeatureRegistrar.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	configurer := newCsrfConfigurer()
	di.SecRegistrar.(security.FeatureRegistrar).RegisterFeature(FeatureId, configurer)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package csrf
import (
"context"
"encoding/gob"
"errors"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/google/uuid"
)
// SessionKeyCsrfToken is the session attribute key under which SessionBackedStore keeps the CSRF token.
const SessionKeyCsrfToken = "CsrfToken"
// Token is a CSRF token: its value plus the names under which a client may submit it.
//
// The header name and parameter name are part of the token in case components down the line need them.
// For example, if the token is rendered as a hidden field in a form, the parameter name is needed.
type Token struct {
// Value is the token value a protected request must submit.
Value string
// ParameterName is the HTTP request parameter in which the CSRF token can be placed.
ParameterName string
// HeaderName is the HTTP header in which the CSRF token can be placed instead of the parameter.
HeaderName string
}
// TokenStore reads, creates and persists the CSRF token associated with a request.
//
// How the token is associated with the request is up to the implementation. The store is also responsible
// for writing to the response when necessary. For example, a cookie-based implementation would write the
// token as a cookie header when saving it.
type TokenStore interface {
	// Generate creates a token carrying the given parameter and header names.
	Generate(c context.Context, parameterName, headerName string) *Token
	// SaveToken persists token for the request in c.
	SaveToken(c context.Context, token *Token) error
	// LoadToken returns the token associated with the request in c. The CSRF middleware treats a nil token
	// with a nil error as "no token yet" and generates one.
	LoadToken(c context.Context) (*Token, error)
}
// SessionBackedStore is a TokenStore that keeps the CSRF token in the request's session under
// SessionKeyCsrfToken.
type SessionBackedStore struct{}
// newSessionBackedStore returns a SessionBackedStore. It first registers *Token with encoding/gob.
// NOTE(review): presumably so tokens stored as session attributes can be gob-encoded by the session store — verify.
func newSessionBackedStore() *SessionBackedStore {
	gob.Register((*Token)(nil))
	return new(SessionBackedStore)
}
// Generate implements TokenStore. It creates a token whose value is a random UUID string and that carries
// the given parameter and header names. The context is not used.
func (store *SessionBackedStore) Generate(c context.Context, parameterName string, headerName string) *Token {
	return &Token{
		Value:         uuid.New().String(),
		ParameterName: parameterName,
		HeaderName:    headerName,
	}
}
// SaveToken implements TokenStore. It sets the token as session attribute SessionKeyCsrfToken and returns
// the result of saving the session. It fails when the request has no session.
func (store *SessionBackedStore) SaveToken(c context.Context, token *Token) error {
	if s := session.Get(c); s != nil {
		s.Set(SessionKeyCsrfToken, token)
		return s.Save()
	}
	return errors.New("can't save csrf token to session, because the request has no session")
}
// LoadToken implements TokenStore. It reads the session attribute SessionKeyCsrfToken.
//
// It returns (nil, nil) when the session holds no token. It returns an error when the request has no
// session, or when the stored attribute is not a *Token.
func (store *SessionBackedStore) LoadToken(c context.Context) (*Token, error) {
	s := session.Get(c)
	if s == nil {
		return nil, errors.New("can't load csrf token from session, because the request has no session")
	}
	attr := s.Get(SessionKeyCsrfToken)
	if attr == nil {
		return nil, nil
	}
	token, ok := attr.(*Token)
	if !ok {
		return nil, errors.New("csrf token in session can't be asserted to be the CSRF token type")
	}
	return token, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/template"
"github.com/gin-gonic/gin"
"net/http"
"sort"
"strings"
"sync"
)
// AccessDeniedHandler handles access control errors (ErrorTypeAccessControl), such as ErrorSubTypeAccessDenied.
type AccessDeniedHandler interface {
	HandleAccessDenied(ctx context.Context, r *http.Request, rw http.ResponseWriter, err error)
}

// AuthenticationErrorHandler handles authentication errors (ErrorTypeAuthentication).
type AuthenticationErrorHandler interface {
	HandleAuthenticationError(ctx context.Context, r *http.Request, rw http.ResponseWriter, err error)
}

// AuthenticationEntryPoint kicks off the authentication process.
type AuthenticationEntryPoint interface {
	Commence(ctx context.Context, r *http.Request, rw http.ResponseWriter, err error)
}

// ErrorHandler handles any error not covered by the handlers above.
type ErrorHandler interface {
	HandleError(ctx context.Context, r *http.Request, rw http.ResponseWriter, err error)
}
/*****************************
Common Impl.
*****************************/
// CompositeAuthenticationErrorHandler implements AuthenticationErrorHandler by invoking each registered
// handler in turn.
//
// Registered handlers are kept free of self-references and duplicates and are sorted with
// order.OrderedFirstCompare. Nested composites are flattened the first time HandleAuthenticationError runs.
// That flattened list is computed only once (sync.Once), so handlers added afterwards are not invoked.
type CompositeAuthenticationErrorHandler struct {
	init      sync.Once
	handlers  []AuthenticationErrorHandler
	flattened []AuthenticationErrorHandler
}

// NewAuthenticationErrorHandler creates a composite of the given handlers (deduplicated and sorted).
func NewAuthenticationErrorHandler(handlers ...AuthenticationErrorHandler) *CompositeAuthenticationErrorHandler {
	composite := &CompositeAuthenticationErrorHandler{}
	composite.handlers = composite.processErrorHandlers(handlers)
	return composite
}

// HandleAuthenticationError implements AuthenticationErrorHandler.
func (h *CompositeAuthenticationErrorHandler) HandleAuthenticationError(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	h.init.Do(func() {
		h.flattened = h.Handlers()
	})
	for _, each := range h.flattened {
		each.HandleAuthenticationError(c, r, rw, err)
	}
}

// Handlers returns the registered authentication error handlers with nested composites recursively
// flattened, sorted with order.OrderedFirstCompare.
func (h *CompositeAuthenticationErrorHandler) Handlers() []AuthenticationErrorHandler {
	result := make([]AuthenticationErrorHandler, 0, len(h.handlers))
	for _, handler := range h.handlers {
		if nested, isComposite := handler.(*CompositeAuthenticationErrorHandler); isComposite {
			result = append(result, nested.Handlers()...)
			continue
		}
		result = append(result, handler)
	}
	sort.SliceStable(result, func(i, j int) bool {
		return order.OrderedFirstCompare(result[i], result[j])
	})
	return result
}

// Size returns the number of directly registered handlers (nested composites count as one).
func (h *CompositeAuthenticationErrorHandler) Size() int {
	return len(h.handlers)
}

// Add registers handler and returns the composite for chaining.
func (h *CompositeAuthenticationErrorHandler) Add(handler AuthenticationErrorHandler) *CompositeAuthenticationErrorHandler {
	h.handlers = h.processErrorHandlers(append(h.handlers, handler))
	return h
}

// Merge registers all handlers directly registered on composite and returns the receiver for chaining.
func (h *CompositeAuthenticationErrorHandler) Merge(composite *CompositeAuthenticationErrorHandler) *CompositeAuthenticationErrorHandler {
	h.handlers = h.processErrorHandlers(append(h.handlers, composite.handlers...))
	return h
}

// processErrorHandlers drops the receiver itself and duplicates, then stable-sorts the rest.
func (h *CompositeAuthenticationErrorHandler) processErrorHandlers(handlers []AuthenticationErrorHandler) []AuthenticationErrorHandler {
	processed := h.removeDuplicates(h.removeSelf(handlers))
	sort.SliceStable(processed, func(i, j int) bool {
		return order.OrderedFirstCompare(processed[i], processed[j])
	})
	return processed
}

// removeSelf filters out the receiver in place, reusing items' backing array.
func (h *CompositeAuthenticationErrorHandler) removeSelf(items []AuthenticationErrorHandler) []AuthenticationErrorHandler {
	kept := items[:0]
	for _, item := range items {
		if self, isComposite := item.(*CompositeAuthenticationErrorHandler); isComposite && self == h {
			continue
		}
		kept = append(kept, item)
	}
	// Nil out the now-unused tail so dropped handlers can be garbage collected.
	for i := len(kept); i < len(items); i++ {
		items[i] = nil
	}
	return kept
}

// removeDuplicates returns a new slice keeping the first occurrence of each handler (interface equality).
func (h *CompositeAuthenticationErrorHandler) removeDuplicates(items []AuthenticationErrorHandler) []AuthenticationErrorHandler {
	seen := make(map[AuthenticationErrorHandler]struct{}, len(items))
	unique := make([]AuthenticationErrorHandler, 0, len(items))
	for _, item := range items {
		if _, dup := seen[item]; !dup {
			seen[item] = struct{}{}
			unique = append(unique, item)
		}
	}
	return unique
}
// CompositeAccessDeniedHandler implements AccessDeniedHandler by invoking each registered handler in turn.
//
// Registered handlers are kept free of self-references and duplicates and are sorted with
// order.OrderedFirstCompare. Nested composites are flattened the first time HandleAccessDenied runs.
// That flattened list is computed only once (sync.Once), so handlers added afterwards are not invoked.
type CompositeAccessDeniedHandler struct {
	init      sync.Once
	handlers  []AccessDeniedHandler
	flattened []AccessDeniedHandler
}

// NewAccessDeniedHandler creates a composite of the given handlers (deduplicated and sorted).
func NewAccessDeniedHandler(handlers ...AccessDeniedHandler) *CompositeAccessDeniedHandler {
	composite := &CompositeAccessDeniedHandler{}
	composite.handlers = composite.processErrorHandlers(handlers)
	return composite
}

// HandleAccessDenied implements AccessDeniedHandler.
func (h *CompositeAccessDeniedHandler) HandleAccessDenied(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	h.init.Do(func() {
		h.flattened = h.Handlers()
	})
	for _, each := range h.flattened {
		each.HandleAccessDenied(c, r, rw, err)
	}
}

// Handlers returns the registered access denied handlers with nested composites recursively flattened,
// sorted with order.OrderedFirstCompare.
func (h *CompositeAccessDeniedHandler) Handlers() []AccessDeniedHandler {
	result := make([]AccessDeniedHandler, 0, len(h.handlers))
	for _, handler := range h.handlers {
		if nested, isComposite := handler.(*CompositeAccessDeniedHandler); isComposite {
			result = append(result, nested.Handlers()...)
			continue
		}
		result = append(result, handler)
	}
	sort.SliceStable(result, func(i, j int) bool {
		return order.OrderedFirstCompare(result[i], result[j])
	})
	return result
}

// Size returns the number of directly registered handlers (nested composites count as one).
func (h *CompositeAccessDeniedHandler) Size() int {
	return len(h.handlers)
}

// Add registers handler and returns the composite for chaining.
func (h *CompositeAccessDeniedHandler) Add(handler AccessDeniedHandler) *CompositeAccessDeniedHandler {
	h.handlers = h.processErrorHandlers(append(h.handlers, handler))
	return h
}

// Merge registers all handlers directly registered on composite and returns the receiver for chaining.
func (h *CompositeAccessDeniedHandler) Merge(composite *CompositeAccessDeniedHandler) *CompositeAccessDeniedHandler {
	h.handlers = h.processErrorHandlers(append(h.handlers, composite.handlers...))
	return h
}

// processErrorHandlers drops the receiver itself and duplicates, then stable-sorts the rest.
func (h *CompositeAccessDeniedHandler) processErrorHandlers(handlers []AccessDeniedHandler) []AccessDeniedHandler {
	processed := h.removeDuplicates(h.removeSelf(handlers))
	sort.SliceStable(processed, func(i, j int) bool {
		return order.OrderedFirstCompare(processed[i], processed[j])
	})
	return processed
}

// removeSelf filters out the receiver in place, reusing items' backing array.
func (h *CompositeAccessDeniedHandler) removeSelf(items []AccessDeniedHandler) []AccessDeniedHandler {
	kept := items[:0]
	for _, item := range items {
		if self, isComposite := item.(*CompositeAccessDeniedHandler); isComposite && self == h {
			continue
		}
		kept = append(kept, item)
	}
	// Nil out the now-unused tail so dropped handlers can be garbage collected.
	for i := len(kept); i < len(items); i++ {
		items[i] = nil
	}
	return kept
}

// removeDuplicates returns a new slice keeping the first occurrence of each handler (interface equality).
func (h *CompositeAccessDeniedHandler) removeDuplicates(items []AccessDeniedHandler) []AccessDeniedHandler {
	seen := make(map[AccessDeniedHandler]struct{}, len(items))
	unique := make([]AccessDeniedHandler, 0, len(items))
	for _, item := range items {
		if _, dup := seen[item]; !dup {
			seen[item] = struct{}{}
			unique = append(unique, item)
		}
	}
	return unique
}
// CompositeErrorHandler implements ErrorHandler by invoking each registered handler in turn.
//
// Registered handlers are kept free of self-references and duplicates and are sorted with
// order.OrderedFirstCompare. Unlike the other composites, nested composites are not flattened: they are
// invoked like any other handler.
type CompositeErrorHandler struct {
	handlers []ErrorHandler
}

// NewErrorHandler creates a composite of the given handlers (deduplicated and sorted).
func NewErrorHandler(handlers ...ErrorHandler) *CompositeErrorHandler {
	composite := &CompositeErrorHandler{}
	composite.handlers = composite.processErrorHandlers(handlers)
	return composite
}

// HandleError implements ErrorHandler.
func (h *CompositeErrorHandler) HandleError(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	for _, each := range h.handlers {
		each.HandleError(c, r, rw, err)
	}
}

// Size returns the number of registered handlers.
func (h *CompositeErrorHandler) Size() int {
	return len(h.handlers)
}

// Add registers handler and returns the composite for chaining.
func (h *CompositeErrorHandler) Add(handler ErrorHandler) *CompositeErrorHandler {
	h.handlers = h.processErrorHandlers(append(h.handlers, handler))
	return h
}

// Merge registers all handlers registered on composite and returns the receiver for chaining.
func (h *CompositeErrorHandler) Merge(composite *CompositeErrorHandler) *CompositeErrorHandler {
	h.handlers = h.processErrorHandlers(append(h.handlers, composite.handlers...))
	return h
}

// processErrorHandlers drops the receiver itself and duplicates, then stable-sorts the rest.
func (h *CompositeErrorHandler) processErrorHandlers(handlers []ErrorHandler) []ErrorHandler {
	processed := h.removeDuplicates(h.removeSelf(handlers))
	sort.SliceStable(processed, func(i, j int) bool {
		return order.OrderedFirstCompare(processed[i], processed[j])
	})
	return processed
}

// removeSelf filters out the receiver in place, reusing items' backing array.
func (h *CompositeErrorHandler) removeSelf(items []ErrorHandler) []ErrorHandler {
	kept := items[:0]
	for _, item := range items {
		if self, isComposite := item.(*CompositeErrorHandler); isComposite && self == h {
			continue
		}
		kept = append(kept, item)
	}
	// Nil out the now-unused tail so dropped handlers can be garbage collected.
	for i := len(kept); i < len(items); i++ {
		items[i] = nil
	}
	return kept
}

// removeDuplicates returns a new slice keeping the first occurrence of each handler (interface equality).
func (h *CompositeErrorHandler) removeDuplicates(items []ErrorHandler) []ErrorHandler {
	seen := make(map[ErrorHandler]struct{}, len(items))
	unique := make([]ErrorHandler, 0, len(items))
	for _, item := range items {
		if _, dup := seen[item]; !dup {
			seen[item] = struct{}{}
			unique = append(unique, item)
		}
	}
	return unique
}
/**************************
Default Impls
***************************/
// DefaultAccessDeniedHandler implements AccessDeniedHandler.
// It responds with 401 Unauthorized when the error is ErrorSubTypeInsufficientAuth, and with
// 403 Forbidden for every other access denied error.
type DefaultAccessDeniedHandler struct {
}

// HandleAccessDenied implements AccessDeniedHandler. Exactly one status is written per call.
// WriteError only detects an already-written response for gin writers (see IsResponseWritten), so
// falling through to the 403 case would write twice to a plain http.ResponseWriter.
func (h *DefaultAccessDeniedHandler) HandleAccessDenied(ctx context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	switch {
	case errors.Is(err, ErrorSubTypeInsufficientAuth):
		WriteError(ctx, r, rw, http.StatusUnauthorized, err)
	default:
		WriteError(ctx, r, rw, http.StatusForbidden, err)
	}
}
// DefaultAuthenticationErrorHandler implements AuthenticationErrorHandler by writing the error with
// status 401 Unauthorized.
type DefaultAuthenticationErrorHandler struct{}

// HandleAuthenticationError implements AuthenticationErrorHandler.
func (h *DefaultAuthenticationErrorHandler) HandleAuthenticationError(ctx context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	const status = http.StatusUnauthorized
	WriteError(ctx, r, rw, status, err)
}
// DefaultErrorHandler implements ErrorHandler by writing every error it receives with status
// 401 Unauthorized.
// NOTE(review): 401 for otherwise-unclassified errors may be surprising; confirm this is intended.
type DefaultErrorHandler struct{}

// HandleError implements ErrorHandler.
func (h *DefaultErrorHandler) HandleError(ctx context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	const status = http.StatusUnauthorized
	WriteError(ctx, r, rw, status, err)
}
/**************************
Common Functions
***************************/
// WriteError writes err as an HTTP error response with the given status code.
// It does nothing when the response has already been written (see IsResponseWritten). The body is encoded
// as JSON when the request accepts or sends JSON (see isJson), and as HTML otherwise.
func WriteError(ctx context.Context, r *http.Request, rw http.ResponseWriter, code int, err error) {
	switch {
	case IsResponseWritten(rw):
		return
	case isJson(r):
		WriteErrorAsJson(ctx, rw, code, err)
	default:
		WriteErrorAsHtml(ctx, rw, code, err)
	}
}
// IsResponseWritten reports whether rw is a gin.ResponseWriter that has already been written to.
// Any other http.ResponseWriter is reported as not written.
func IsResponseWritten(rw http.ResponseWriter) bool {
	if ginRw, ok := rw.(gin.ResponseWriter); ok {
		return ginRw.Written()
	}
	return false
}
func WriteErrorAsHtml(ctx context.Context, rw http.ResponseWriter, code int, err error) {
httpError := web.NewHttpError(code, err)
template.TemplateErrorEncoder(ctx, httpError, rw)
}
// WriteErrorAsJson wraps err into a web HTTP error with the given status code and encodes it with the
// encoder returned by web.JsonErrorEncoder.
func WriteErrorAsJson(ctx context.Context, rw http.ResponseWriter, code int, err error) {
	wrapped := web.NewHttpError(code, err)
	encode := web.JsonErrorEncoder()
	encode(ctx, wrapped, rw)
}
/**************************
Helpers
***************************/
// isJson reports whether the request's Accept or Content-Type header contains "application/json".
// The check is a case-sensitive substring match.
func isJson(r *http.Request) bool {
	// TODO should be more comprehensive than this
	const jsonMediaType = "application/json"
	for _, header := range []string{"Accept", "Content-Type"} {
		if strings.Contains(r.Header.Get(header), jsonMediaType) {
			return true
		}
	}
	return false
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package errorhandling
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
)
var (
// FeatureId identifies the error handling feature ("ErrorHandling").
// NOTE(review): security.FeatureOrderErrorHandling presumably sets the feature's order among features — verify.
FeatureId = security.FeatureId("ErrorHandling", security.FeatureOrderErrorHandling)
)
// ErrorHandlingFeature configures how security errors are turned into HTTP responses: the authentication
// entry point, the access denied handler, the authentication error handler and additional error handlers.
// Create it with New or Configure. With the zero value, errorHandler is nil, and AdditionalErrorHandler
// dereferences it.
//goland:noinspection GoNameStartsWithPackageName
type ErrorHandlingFeature struct {
// authEntryPoint commences authentication on insufficient-auth errors; may stay nil.
authEntryPoint security.AuthenticationEntryPoint
// accessDeniedHandler defaults to security.DefaultAccessDeniedHandler when nil at apply time.
accessDeniedHandler security.AccessDeniedHandler
// authErrorHandler defaults to security.DefaultAuthenticationErrorHandler when nil at apply time.
authErrorHandler security.AuthenticationErrorHandler
// errorHandler collects handlers for all other security errors.
errorHandler *security.CompositeErrorHandler
}
// Identifier implements security.Feature. It always returns FeatureId.
func (f *ErrorHandlingFeature) Identifier() security.FeatureIdentifier {
	return FeatureId
}

// AuthenticationEntryPoint sets the entry point that commences authentication on insufficient-auth errors.
// When left nil, those errors are handled by the access denied handler.
func (f *ErrorHandlingFeature) AuthenticationEntryPoint(entryPoint security.AuthenticationEntryPoint) *ErrorHandlingFeature {
	f.authEntryPoint = entryPoint
	return f
}

// AccessDeniedHandler sets the handler contributed to the shared access denied handler when the feature is
// applied. When left nil, security.DefaultAccessDeniedHandler is used.
func (f *ErrorHandlingFeature) AccessDeniedHandler(handler security.AccessDeniedHandler) *ErrorHandlingFeature {
	f.accessDeniedHandler = handler
	return f
}

// AuthenticationErrorHandler sets the handler contributed to the shared authentication error handler when
// the feature is applied. When left nil, security.DefaultAuthenticationErrorHandler is used.
func (f *ErrorHandlingFeature) AuthenticationErrorHandler(handler security.AuthenticationErrorHandler) *ErrorHandlingFeature {
	f.authErrorHandler = handler
	return f
}

// AdditionalErrorHandler adds a security.ErrorHandler to the existing list.
// It is typically used by other features. Authentication and access control errors are routed to the
// AuthenticationErrorHandler, AccessDeniedHandler and AuthenticationEntryPoint, so these handlers only see
// the remaining security errors.
func (f *ErrorHandlingFeature) AdditionalErrorHandler(handler security.ErrorHandler) *ErrorHandlingFeature {
	f.errorHandler.Add(handler)
	return f
}
// Configure enables the error handling feature on ws and returns it for DSL-style configuration.
// The returned value is what FeatureModifier.Enable returns for a freshly created feature, asserted to
// *ErrorHandlingFeature. It panics when ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *ErrorHandlingFeature {
	feature := New()
	if fc, ok := ws.(security.FeatureModifier); ok {
		return fc.Enable(feature).(*ErrorHandlingFeature)
	}
	panic(fmt.Errorf("unable to configure error handling: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
}
// New creates an ErrorHandlingFeature with an empty error handler list. It is the standard
// security.Feature entrypoint, DSL style, used with security.WebSecurity.
func New() *ErrorHandlingFeature {
	feature := &ErrorHandlingFeature{}
	feature.errorHandler = security.NewErrorHandler()
	return feature
}
// ErrorHandlingConfigurer applies an ErrorHandlingFeature to a security.WebSecurity. It holds no state.
//goland:noinspection GoNameStartsWithPackageName
type ErrorHandlingConfigurer struct{}

// newErrorHandlingConfigurer returns a new ErrorHandlingConfigurer.
func newErrorHandlingConfigurer() *ErrorHandlingConfigurer {
	return new(ErrorHandlingConfigurer)
}
// Apply validates the feature (filling in defaults) and contributes its handlers to the composite handlers
// shared on ws. It then adds the error handling middleware, ordered at security.MWOrderErrorHandling.
// It panics when feature is not an *ErrorHandlingFeature or when the shared handlers have unexpected types.
func (ehc *ErrorHandlingConfigurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	f := feature.(*ErrorHandlingFeature)
	if err := ehc.validate(f, ws); err != nil {
		return err
	}

	// Contribute to the composites shared across the whole WebSecurity.
	sharedAuthErrHandler := ws.Shared(security.WSSharedKeyCompositeAuthErrorHandler).(*security.CompositeAuthenticationErrorHandler)
	sharedAuthErrHandler.Add(f.authErrorHandler)
	sharedAccessDeniedHandler := ws.Shared(security.WSSharedKeyCompositeAccessDeniedHandler).(*security.CompositeAccessDeniedHandler)
	sharedAccessDeniedHandler.Add(f.accessDeniedHandler)

	mw := NewErrorHandlingMiddleware()
	mw.entryPoint = f.authEntryPoint
	mw.accessDeniedHandler = sharedAccessDeniedHandler
	mw.authErrorHandler = sharedAuthErrHandler
	mw.errorHandler = f.errorHandler

	ws.Add(middleware.NewBuilder("error handling").
		Order(security.MWOrderErrorHandling).
		Use(mw.HandlerFunc()))
	return nil
}
// validate fills in the defaults of f before it is applied to ws, logging each fallback it picks:
//   - a nil authentication entry point stays nil; insufficient-auth errors then reach the access denied handler;
//   - a nil authentication error handler becomes security.DefaultAuthenticationErrorHandler;
//   - a nil access denied handler becomes security.DefaultAccessDeniedHandler.
//
// It always adds a security.DefaultErrorHandler to f's error handlers. It currently never returns an error.
func (ehc *ErrorHandlingConfigurer) validate(f *ErrorHandlingFeature, ws security.WebSecurity) error {
	if f.authEntryPoint == nil {
		logger.WithContext(ws.Context()).Infof("authentication entry point is not set, fallback to access denied handler - [%v]", log.Capped(ws, 80))
	}

	if f.authErrorHandler == nil {
		logger.WithContext(ws.Context()).Infof("using default authentication error handler - [%v]", log.Capped(ws, 80))
		f.authErrorHandler = &security.DefaultAuthenticationErrorHandler{}
	}

	if f.accessDeniedHandler == nil {
		logger.WithContext(ws.Context()).Infof("using default access denied handler - [%v]", log.Capped(ws, 80))
		f.accessDeniedHandler = &security.DefaultAccessDeniedHandler{}
	}

	// Always add a default handler. DefaultErrorHandler is unordered, which is intended to place it at the
	// end (presumably OrderedFirstCompare sorts unordered items after ordered ones — verify).
	f.errorHandler.Add(&security.DefaultErrorHandler{})
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package errorhandling
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/template"
"net/http"
)
// ErrorWithStatus is an endpoint that reports the error recorded in the session before a redirect.
// It reads the status code and error via s.Flash, using redirect.FlashKeyPreviousStatusCode and
// redirect.FlashKeyPreviousError. It falls back to 500 and "unknown error" when they are missing or have
// unexpected types. Without a session it returns 500 with "error message not available".
func ErrorWithStatus(ctx context.Context, _ web.EmptyRequest) (int, *template.ModelView, error) {
	s := session.Get(ctx)
	if s == nil {
		return http.StatusInternalServerError, nil, fmt.Errorf("error message not available")
	}

	code := http.StatusInternalServerError
	if previousCode, ok := s.Flash(redirect.FlashKeyPreviousStatusCode).(int); ok {
		code = previousCode
	}

	cause, ok := s.Flash(redirect.FlashKeyPreviousError).(error)
	if !ok {
		cause = errors.New("unknown error")
	}
	return code, nil, web.NewHttpError(code, cause)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package errorhandling
import (
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
errorutils "github.com/cisco-open/go-lanai/pkg/utils/error"
"github.com/gin-gonic/gin"
"strings"
)
// ErrorHandlingMiddleware dispatches security errors raised downstream to the configured handlers.
// Its fields are populated by ErrorHandlingConfigurer.Apply.
//goland:noinspection GoNameStartsWithPackageName
type ErrorHandlingMiddleware struct {
// entryPoint commences authentication on insufficient-auth errors; may be nil.
entryPoint security.AuthenticationEntryPoint
// accessDeniedHandler handles access control errors.
accessDeniedHandler security.AccessDeniedHandler
// authErrorHandler handles authentication and internal errors.
authErrorHandler security.AuthenticationErrorHandler
// errorHandler handles every other security error.
errorHandler security.ErrorHandler
}
// NewErrorHandlingMiddleware returns a middleware with no handlers set; callers assign them before use.
func NewErrorHandlingMiddleware() *ErrorHandlingMiddleware {
	return new(ErrorHandlingMiddleware)
}
// HandlerFunc returns the gin middleware that turns security errors into responses.
// It runs the rest of the chain (ctx.Next), then handles the first security error recorded on the gin
// context (tryHandleErrors). If the chain panics, tryHandleErrors is skipped. tryRecover then handles the
// panic value if it is a security error and re-panics otherwise.
// NOTE: tryRecover must remain the deferred call itself. recover() only stops a panic when it is called
// directly by the deferred function, so wrapping tryRecover in a closure would break recovery.
func (eh *ErrorHandlingMiddleware) HandlerFunc() gin.HandlerFunc {
return func(ctx *gin.Context) {
defer eh.tryRecover(ctx)
ctx.Next()
eh.tryHandleErrors(ctx)
}
}
// tryRecover must be deferred directly. It recovers a panic whose value is an error matching
// security.ErrorTypeSecurity and handles that error. Any other panic value is re-panicked unchanged.
func (eh *ErrorHandlingMiddleware) tryRecover(c *gin.Context) {
	recovered := recover()
	if recovered == nil {
		// nothing to recover
		return
	}
	if err, isErr := recovered.(error); isErr && errors.Is(err, security.ErrorTypeSecurity) {
		eh.handleError(c, err)
		return
	}
	// not something we can handle
	panic(recovered)
}
// tryHandleErrors handles the first error recorded on the gin context that matches
// security.ErrorTypeSecurity. Any later errors are ignored.
func (eh *ErrorHandlingMiddleware) tryHandleErrors(c *gin.Context) {
	for _, recorded := range c.Errors {
		if !errors.Is(recorded.Err, security.ErrorTypeSecurity) {
			continue
		}
		eh.handleError(c, recorded.Err)
		return
	}
}
// handleError logs err and, if nothing has been written to the response yet, dispatches it to exactly one
// handler. The order of the cases matters:
//   - internal errors (ErrorTypeInternal) go to the authentication error handler;
//   - insufficient-auth errors go to the entry point when one is configured. Without an entry point they
//     fall through to the access control case (ErrorSubTypeCodeInsufficientAuth is declared as a sub type
//     of ErrorTypeCodeAccessControl);
//   - authentication errors go to the authentication error handler;
//   - access control errors go to the access denied handler;
//   - anything else goes to the generic error handler.
func (eh *ErrorHandlingMiddleware) handleError(c *gin.Context, err error) {
eh.logError(c, err)
if c.Writer.Written() {
return
}
// we assume eh.authErrorHandler and eh.accessDeniedHandler is always not nil (guaranteed by ErrorHandlingConfigurer)
switch {
case errors.Is(err, security.ErrorTypeInternal):
eh.authErrorHandler.HandleAuthenticationError(c, c.Request, c.Writer, err)
case eh.entryPoint != nil && errors.Is(err, security.ErrorSubTypeInsufficientAuth):
eh.entryPoint.Commence(c, c.Request, c.Writer, err)
case errors.Is(err, security.ErrorTypeAuthentication):
eh.authErrorHandler.HandleAuthenticationError(c, c.Request, c.Writer, err)
case errors.Is(err, security.ErrorTypeAccessControl):
eh.accessDeniedHandler.HandleAccessDenied(c, c.Request, c.Writer, err)
default:
eh.errorHandler.HandleError(c, c.Request, c.Writer, err)
}
}
// logError logs err at debug level, together with its chain of causes. It follows
// errorutils.NestedError.Cause() for as long as each cause implements NestedError and is non-nil.
//
//nolint:errorlint
func (eh *ErrorHandlingMiddleware) logError(c *gin.Context, err error) {
	var chain []string
	cause := err
	for cause != nil {
		chain = append(chain, cause.Error())
		nested, ok := cause.(errorutils.NestedError)
		if !ok {
			break
		}
		cause = nested.Cause()
	}
	logger.WithContext(c.Request.Context()).Debugf("[Error]: %s", strings.Join(chain, " - [Caused By]: "))
}
/**************************
Helpers
***************************/
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package errorhandling
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security"
"go.uber.org/fx"
)
// logger is the logger of the security error handling package, named "SEC.Err".
var logger = log.New("SEC.Err")
// Module is the bootstrap module of the error handling feature. Its only fx option invokes register.
//goland:noinspection GoNameStartsWithPackageName
var Module = &bootstrap.Module{
Name: "error handling",
Precedence: security.MinSecurityPrecedence + 20,
Options: []fx.Option{
fx.Invoke(register),
},
}
// init registers Module with bootstrap.
func init() {
bootstrap.Register(Module)
}
// initDI holds the dependencies injected into register. SecRegistrar is optional; register does nothing
// when it is nil.
type initDI struct {
fx.In
SecRegistrar security.Registrar `optional:"true"`
}
// register registers the error handling feature and its configurer with the security registrar, if one was
// injected. It panics if the registrar does not implement security.FeatureRegistrar.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	configurer := newErrorHandlingConfigurer()
	di.SecRegistrar.(security.FeatureRegistrar).RegisterFeature(FeatureId, configurer)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"errors"
errorutils "github.com/cisco-open/go-lanai/pkg/utils/error"
)
const (
// security reserved
// Reserved is the code of the security error category: 11 shifted left by errorutils.ReservedOffset.
// ErrorTypeSecurity is created with this code.
Reserved = 11 << errorutils.ReservedOffset
)
// All "Type" values are used as mask.
// Each type code is Reserved + n<<errorutils.ErrorTypeOffset, with n = 1, 2, ... in declaration order.
// n starts at 1 because the blank identifier consumes iota 0; the following lines repeat the expression.
const (
_ = iota
ErrorTypeCodeAuthentication = Reserved + iota<<errorutils.ErrorTypeOffset
ErrorTypeCodeAccessControl
ErrorTypeCodeInternal
ErrorTypeCodeOAuth2
ErrorTypeCodeSaml
ErrorTypeCodeOidc
ErrorTypeCodeTenancy
)
// All "SubType" values are used as mask
// sub types of ErrorTypeCodeAuthentication: ErrorTypeCodeAuthentication + n<<errorutils.ErrorSubTypeOffset,
// with n = 1, 2, ... in declaration order.
const (
_ = iota
ErrorSubTypeCodeInternal = ErrorTypeCodeAuthentication + iota<<errorutils.ErrorSubTypeOffset
ErrorSubTypeCodeUsernamePasswordAuth
ErrorSubTypeCodeExternalSamlAuth
ErrorSubTypeCodeAuthWarning
)
// ErrorSubTypeCodeInternal
// Concrete error codes of sub type ErrorSubTypeCodeInternal: ErrorSubTypeCodeInternal + n, n = 1, 2, ...
const (
_ = iota
ErrorCodeAuthenticatorNotAvailable = ErrorSubTypeCodeInternal + iota
)
// ErrorSubTypeCodeUsernamePasswordAuth
// Concrete error codes of sub type ErrorSubTypeCodeUsernamePasswordAuth: ErrorSubTypeCodeUsernamePasswordAuth + n,
// n = 1, 2, ...
const (
_ = iota
ErrorCodeUsernameNotFound = ErrorSubTypeCodeUsernamePasswordAuth + iota
ErrorCodeBadCredentials
ErrorCodeCredentialsExpired
ErrorCodeMaxAttemptsReached
ErrorCodeAccountStatus
)
// All "SubType" values are used as mask
// sub types of ErrorTypeCodeAccessControl: ErrorTypeCodeAccessControl + n<<errorutils.ErrorSubTypeOffset,
// with n = 1, 2, ... in declaration order.
const (
_ = iota
ErrorSubTypeCodeAccessDenied = ErrorTypeCodeAccessControl + iota<<errorutils.ErrorSubTypeOffset
ErrorSubTypeCodeInsufficientAuth
ErrorSubTypeCodeCsrf
)
// All "SubType" values are used as mask
// sub types of ErrorTypeCodeTenancy: ErrorTypeCodeTenancy + n<<errorutils.ErrorSubTypeOffset, n = 1, 2, ...
const (
_ = iota
ErrorSubTypeCodeTenantInvalid = ErrorTypeCodeTenancy + iota<<errorutils.ErrorSubTypeOffset
ErrorSubTypeCodeTenantAccessDenied
)
// Concrete error codes of sub type ErrorSubTypeCodeCsrf: ErrorSubTypeCodeCsrf + n, n = 1, 2, ...
const (
_ = iota
ErrorCodeMissingCsrfToken = ErrorSubTypeCodeCsrf + iota
ErrorCodeInvalidCsrfToken
)
// ErrorTypes and ErrorSubTypes, intended for errors.Is comparisons against coded errors.
// ErrorTypeSecurity is the category covering every security error code.
var (
	ErrorTypeSecurity       = NewErrorCategory(Reserved, errors.New("error type: security"))
	ErrorTypeAuthentication = NewErrorType(ErrorTypeCodeAuthentication, errors.New("error type: authentication"))
	ErrorTypeAccessControl  = NewErrorType(ErrorTypeCodeAccessControl, errors.New("error type: access control"))
	ErrorTypeInternal       = NewErrorType(ErrorTypeCodeInternal, errors.New("error type: internal"))
	ErrorTypeSaml           = NewErrorType(ErrorTypeCodeSaml, errors.New("error type: saml"))
	ErrorTypeOidc           = NewErrorType(ErrorTypeCodeOidc, errors.New("error type: oidc"))

	ErrorSubTypeInternalError        = NewErrorSubType(ErrorSubTypeCodeInternal, errors.New("error sub-type: internal"))
	ErrorSubTypeUsernamePasswordAuth = NewErrorSubType(ErrorSubTypeCodeUsernamePasswordAuth, errors.New("error sub-type: username password auth"))
	ErrorSubTypeExternalSamlAuth     = NewErrorSubType(ErrorSubTypeCodeExternalSamlAuth, errors.New("error sub-type: external saml"))
	ErrorSubTypeAuthWarning          = NewErrorSubType(ErrorSubTypeCodeAuthWarning, errors.New("error sub-type: auth warning"))
	ErrorSubTypeAccessDenied         = NewErrorSubType(ErrorSubTypeCodeAccessDenied, errors.New("error sub-type: access denied"))
	ErrorSubTypeInsufficientAuth     = NewErrorSubType(ErrorSubTypeCodeInsufficientAuth, errors.New("error sub-type: insufficient auth"))
	ErrorSubTypeCsrf                 = NewErrorSubType(ErrorSubTypeCodeCsrf, errors.New("error sub-type: csrf"))
)
// Concrete error, can be used in errors.Is for exact match
var (
ErrorInvalidTenantId = NewCodedError(ErrorSubTypeCodeTenantInvalid, "Invalid tenant Id")
ErrorTenantAccessDenied = NewCodedError(ErrorSubTypeCodeTenantAccessDenied, "No Access to the tenant")
)
// init reserves the security error category with errorutils.
func init() {
	errorutils.Reserve(ErrorTypeSecurity)
}

// CodedError implements errorutils.ErrorCoder, errorutils.ComparableErrorCoder, errorutils.NestedError
// by embedding errorutils.CodedError.
type CodedError struct {
	errorutils.CodedError
}
/************************
Constructors
*************************/
// NewCodedError creates a concrete error with the given code.
// A concrete error is not meant to be used as an ErrorType or ErrorSubType comparison target.
// Supported items for e are string, error and fmt.Stringer.
func NewCodedError(code int64, e interface{}, causes ...interface{}) *CodedError {
	inner := errorutils.NewCodedError(code, e, causes...)
	return &CodedError{CodedError: *inner}
}

// NewErrorCategory creates an error category with the given code, wrapped as a security CodedError.
func NewErrorCategory(code int64, e error) *CodedError {
	inner := errorutils.NewErrorCategory(code, e)
	return &CodedError{CodedError: *inner}
}

// NewErrorType creates an error type with the given code (delegates to errorutils.NewErrorType).
func NewErrorType(code int64, e error) error {
	return errorutils.NewErrorType(code, e)
}

// NewErrorSubType creates an error sub-type with the given code (delegates to errorutils.NewErrorSubType).
func NewErrorSubType(code int64, e error) error {
	return errorutils.NewErrorSubType(code, e)
}
/* InternalError family */

// NewInternalError creates a coded error of type ErrorTypeCodeInternal whose message is text.
func NewInternalError(text string, causes ...interface{}) error {
	base := errors.New(text)
	return NewCodedError(ErrorTypeCodeInternal, base, causes...)
}
/* AuthenticationError family */

// NewAuthenticationError creates a coded error with code ErrorTypeCodeAuthentication.
func NewAuthenticationError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorTypeCodeAuthentication, value, causes...)
}

// NewInternalAuthenticationError creates a coded error with code ErrorSubTypeCodeInternal.
func NewInternalAuthenticationError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorSubTypeCodeInternal, value, causes...)
}

// NewAuthenticationWarningError creates a coded error with code ErrorSubTypeCodeAuthWarning.
func NewAuthenticationWarningError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorSubTypeCodeAuthWarning, value, causes...)
}

// NewAuthenticatorNotAvailableError creates a coded error with code ErrorCodeAuthenticatorNotAvailable.
func NewAuthenticatorNotAvailableError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorCodeAuthenticatorNotAvailable, value, causes...)
}

// NewExternalSamlAuthenticationError creates a coded error with code ErrorSubTypeCodeExternalSamlAuth.
func NewExternalSamlAuthenticationError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorSubTypeCodeExternalSamlAuth, value, causes...)
}

// NewUsernameNotFoundError creates a coded error with code ErrorCodeUsernameNotFound.
func NewUsernameNotFoundError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorCodeUsernameNotFound, value, causes...)
}

// NewBadCredentialsError creates a coded error with code ErrorCodeBadCredentials.
func NewBadCredentialsError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorCodeBadCredentials, value, causes...)
}

// NewCredentialsExpiredError creates a coded error with code ErrorCodeCredentialsExpired.
func NewCredentialsExpiredError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorCodeCredentialsExpired, value, causes...)
}

// NewMaxAttemptsReachedError creates a coded error with code ErrorCodeMaxAttemptsReached.
func NewMaxAttemptsReachedError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorCodeMaxAttemptsReached, value, causes...)
}

// NewAccountStatusError creates a coded error with code ErrorCodeAccountStatus.
func NewAccountStatusError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorCodeAccountStatus, value, causes...)
}
/* AccessControlError family */

// NewAccessControlError creates a coded error with code ErrorTypeCodeAccessControl.
func NewAccessControlError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorTypeCodeAccessControl, value, causes...)
}

// NewAccessDeniedError creates a coded error with code ErrorSubTypeCodeAccessDenied.
func NewAccessDeniedError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorSubTypeCodeAccessDenied, value, causes...)
}

// NewInsufficientAuthError creates a coded error with code ErrorSubTypeCodeInsufficientAuth.
func NewInsufficientAuthError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorSubTypeCodeInsufficientAuth, value, causes...)
}

/* CSRF errors (sub-type ErrorSubTypeCodeCsrf) */

// NewMissingCsrfTokenError creates a coded error with code ErrorCodeMissingCsrfToken.
func NewMissingCsrfTokenError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorCodeMissingCsrfToken, value, causes...)
}

// NewInvalidCsrfTokenError creates a coded error with code ErrorCodeInvalidCsrfToken.
func NewInvalidCsrfTokenError(value interface{}, causes ...interface{}) error {
	return NewCodedError(ErrorCodeInvalidCsrfToken, value, causes...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package formlogin
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/csrf"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
"github.com/cisco-open/go-lanai/pkg/security/request_cache"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/mapping"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
"net/http"
)
// FeatureId identifies the form login feature, ordered by security.FeatureOrderFormLogin.
var FeatureId = security.FeatureId("FormLogin", security.FeatureOrderFormLogin)
// FormLoginConfigurer applies FormLoginFeature to a WebSecurity.
// configured records whether login/MFA endpoints have already been registered by this configurer.
//
//goland:noinspection GoNameStartsWithPackageName
type FormLoginConfigurer struct {
	serverProps web.ServerProperties
	configured  bool
}

// newFormLoginConfigurer creates a FormLoginConfigurer using the given server properties
// (the context path is used for the remember-username handler).
func newFormLoginConfigurer(serverProps web.ServerProperties) *FormLoginConfigurer {
	configurer := new(FormLoginConfigurer)
	configurer.serverProps = serverProps
	return configurer
}
// Apply configures form login (and optionally MFA) on the given WebSecurity.
//
// The feature is validated (filling in defaults) and error handling is configured on every call.
// Login/MFA pages, processing endpoints and CSRF protection are registered only the first time
// this configurer is applied; later calls log a warning and return without changing them.
// An error is returned if feature is not a *FormLoginFeature or if any step fails.
func (c *FormLoginConfigurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	// Assert once and fail gracefully instead of panicking on an unexpected feature type.
	f, ok := feature.(*FormLoginFeature)
	if !ok {
		return fmt.Errorf("unable to apply form login: unsupported feature type [%T]", feature)
	}

	// Verify and fill defaults
	if err := c.validate(f, ws); err != nil {
		return err
	}
	if err := c.configureErrorHandling(f, ws); err != nil {
		return err
	}

	if c.configured {
		logger.WithContext(ws.Context()).Warnf(`attempting to reconfigure login forms for WebSecurity [%v]. `+
			`Changes will not be applied. If this is expected, please ignore this warning`, ws)
		return nil
	}
	c.configured = true

	// one-time registration steps, executed in order; the first failure aborts
	steps := []func(*FormLoginFeature, security.WebSecurity) error{
		c.configureLoginPage,
		c.configureMfaPage,
		c.configureLoginProcessing,
		c.configureMfaProcessing,
		c.configureCSRF,
	}
	for _, step := range steps {
		if err := step(f, ws); err != nil {
			return err
		}
	}
	return nil
}
// validate checks required settings of f and fills in derived defaults:
//   - loginUrl is required;
//   - successHandler defaults to a saved-request redirect handler;
//   - loginProcessUrl defaults to loginUrl;
//   - without a custom failure handler, loginErrorUrl defaults to loginUrl + "?error=true";
//   - with MFA enabled, mfaUrl is required and mfaVerifyUrl defaults to mfaUrl;
//   - without a custom failure handler, mfaErrorUrl defaults to mfaUrl + "?error=true".
//
// Note: f may already be partially updated when an error is returned.
func (c *FormLoginConfigurer) validate(f *FormLoginFeature, _ security.WebSecurity) error {
	if len(f.loginUrl) == 0 {
		return fmt.Errorf("loginUrl is missing for form login")
	}
	if f.successHandler == nil {
		f.successHandler = c.defaultSavedRequestSuccessHandler()
	}
	if len(f.loginProcessUrl) == 0 {
		f.loginProcessUrl = f.loginUrl
	}

	useErrorUrls := f.failureHandler == nil
	if useErrorUrls && len(f.loginErrorUrl) == 0 {
		f.loginErrorUrl = f.loginUrl + "?error=true"
	}

	if f.mfaEnabled {
		switch {
		case len(f.mfaUrl) == 0:
			return fmt.Errorf("mfaUrl is missing for MFA")
		case len(f.mfaVerifyUrl) == 0:
			f.mfaVerifyUrl = f.mfaUrl
		}
	}

	if useErrorUrls && len(f.mfaErrorUrl) == 0 {
		f.mfaErrorUrl = f.mfaUrl + "?error=true"
	}
	return nil
}
// configureErrorHandling installs the authentication entry point and error handler on ws.
//
// Unauthenticated requests are redirected to loginUrl. When MFA is enabled, the entry point
// redirects to mfaUrl while MFA is pending, and the failure handler is wrapped (once) so errors
// during a pending MFA go to mfaErrorUrl. If f has no failure handler, a redirect to loginErrorUrl
// becomes the failure handler. CSRF denials are always redirected to loginErrorUrl.
func (c *FormLoginConfigurer) configureErrorHandling(f *FormLoginFeature, ws security.WebSecurity) error {
	loginErrorRedirect := redirect.NewRedirectWithURL(f.loginErrorUrl)
	mfaErrorRedirect := redirect.NewRedirectWithURL(f.mfaErrorUrl)
	if f.failureHandler == nil {
		f.failureHandler = loginErrorRedirect
	}

	var entryPoint security.AuthenticationEntryPoint = redirect.NewRedirectWithURL(f.loginUrl)
	if f.mfaEnabled {
		if _, alreadyWrapped := f.failureHandler.(*MfaAwareAuthenticationErrorHandler); !alreadyWrapped {
			f.failureHandler = &MfaAwareAuthenticationErrorHandler{
				delegate:           f.failureHandler,
				mfaPendingDelegate: mfaErrorRedirect,
			}
		}
		loginEntryPoint := entryPoint
		entryPoint = &MfaAwareAuthenticationEntryPoint{
			delegate:           loginEntryPoint,
			mfaPendingDelegate: redirect.NewRedirectWithURL(f.mfaUrl),
		}
	}

	// override entry point and error handler
	savingEntryPoint := request_cache.NewSaveRequestEntryPoint(entryPoint)
	errorhandling.Configure(ws).
		AuthenticationEntryPoint(savingEntryPoint).
		AuthenticationErrorHandler(f.failureHandler)

	// adding CSRF protection err handler, while keeping default
	csrf.Configure(ws).CsrfDeniedHandler(loginErrorRedirect)
	return nil
}
// configureLoginPage makes ws intercept GET requests to the login page and permits
// everyone to access it (highest order).
func (c *FormLoginConfigurer) configureLoginPage(f *FormLoginFeature, ws security.WebSecurity) error {
	const method = http.MethodGet
	pageRoute := matcher.RouteWithURL(f.loginUrl, method)
	pageRequest := matcher.RequestWithURL(f.loginUrl, method)

	ws.Route(pageRoute)
	access.Configure(ws).
		Request(pageRequest).
		WithOrder(order.Highest).
		PermitAll()
	return nil
}
// configureMfaPage makes ws intercept GET requests to the MFA page and restricts it to
// authentications holding both the MFA-pending and OTP-id special permissions.
func (c *FormLoginConfigurer) configureMfaPage(f *FormLoginFeature, ws security.WebSecurity) error {
	const method = http.MethodGet
	pageRoute := matcher.RouteWithURL(f.mfaUrl, method)
	pageRequest := matcher.RequestWithURL(f.mfaUrl, method)

	ws.Route(pageRoute)
	access.Configure(ws).
		Request(pageRequest).
		WithOrder(order.Highest).
		HasPermissions(passwd.SpecialPermissionMFAPending, passwd.SpecialPermissionOtpId)
	return nil
}
// configureLoginProcessing registers the POST login-processing endpoint: a route on ws,
// the form authentication middleware bound to that route, and a no-op endpoint mapping
// so that requests to the URL reach the middleware.
func (c *FormLoginConfigurer) configureLoginProcessing(f *FormLoginFeature, ws security.WebSecurity) error {
	processRoute := matcher.RouteWithURL(f.loginProcessUrl, http.MethodPost)
	ws.Route(processRoute)

	// Note: since this MW handles a new path, we add middleware as-is instead of a security.MiddlewareTemplate
	formAuth := NewFormAuthenticationMiddleware(func(opts *FormAuthMWOptions) {
		opts.Authenticator = ws.Authenticator()
		opts.SuccessHandler = c.effectiveSuccessHandler(f, ws)
		opts.UsernameParam = f.usernameParam
		opts.PasswordParam = f.passwordParam
	})
	ws.Add(middleware.NewBuilder("form login").
		ApplyTo(processRoute).
		Order(security.MWOrderFormAuth).
		Use(formAuth.LoginProcessHandlerFunc()))

	// dummy endpoint so the route exists and triggers the middleware above
	dummyEndpoint := mapping.Post(f.loginProcessUrl).
		HandlerFunc(security.NoopHandlerFunc()).
		Name("login process dummy")
	ws.Add(dummyEndpoint)
	return nil
}
// configureMfaProcessing registers the POST OTP-verify and OTP-refresh endpoints: routes on ws,
// one MFA middleware per route, no-op endpoint mappings that trigger them, and an access rule
// requiring both the MFA-pending and OTP-id special permissions on those requests.
func (c *FormLoginConfigurer) configureMfaProcessing(f *FormLoginFeature, ws security.WebSecurity) error {
	verifyRoute := matcher.RouteWithURL(f.mfaVerifyUrl, http.MethodPost)
	refreshRoute := matcher.RouteWithURL(f.mfaRefreshUrl, http.MethodPost)
	verifyRequest := matcher.RequestWithURL(f.mfaVerifyUrl, http.MethodPost)
	mfaRequests := verifyRequest.Or(matcher.RequestWithURL(f.mfaRefreshUrl, http.MethodPost))
	ws.Route(verifyRoute).Route(refreshRoute)

	// Note: since these MWs handle new paths, we add middleware as-is instead of a security.MiddlewareTemplate
	mfaAuth := NewMfaAuthenticationMiddleware(func(opts *MfaMWOptions) {
		opts.Authenticator = ws.Authenticator()
		opts.SuccessHandler = c.effectiveSuccessHandler(f, ws)
		opts.OtpParam = f.otpParam
	})
	verifyMW := middleware.NewBuilder("otp verify").
		ApplyTo(verifyRoute).
		Order(security.MWOrderFormAuth).
		Use(mfaAuth.OtpVerifyHandlerFunc())
	refreshMW := middleware.NewBuilder("otp refresh").
		ApplyTo(refreshRoute).
		Order(security.MWOrderFormAuth).
		Use(mfaAuth.OtpRefreshHandlerFunc())
	ws.Add(verifyMW, refreshMW)

	// dummy endpoints so the routes exist and trigger the middlewares above
	addDummyEndpoint := func(url, name string) {
		ws.Add(mapping.Post(url).
			HandlerFunc(security.NoopHandlerFunc()).
			Name(name))
	}
	addDummyEndpoint(f.mfaVerifyUrl, "otp verify dummy")
	addDummyEndpoint(f.mfaRefreshUrl, "otp refresh dummy")

	access.Configure(ws).
		Request(mfaRequests).
		WithOrder(order.Highest).
		HasPermissions(passwd.SpecialPermissionMFAPending, passwd.SpecialPermissionOtpId)
	return nil
}
// configureCSRF adds CSRF protection to the POST login-processing, OTP-verify and
// OTP-refresh requests.
func (c *FormLoginConfigurer) configureCSRF(f *FormLoginFeature, ws security.WebSecurity) error {
	loginProcess := matcher.RequestWithURL(f.loginProcessUrl, http.MethodPost)
	otpVerify := matcher.RequestWithURL(f.mfaVerifyUrl, http.MethodPost)
	otpRefresh := matcher.RequestWithURL(f.mfaRefreshUrl, http.MethodPost)

	protected := loginProcess.Or(otpVerify).Or(otpRefresh)
	csrf.Configure(ws).AddCsrfProtectionMatcher(protected)
	return nil
}
// effectiveSuccessHandler builds the success handler used by login and MFA processing.
//
// When MFA is enabled, f.successHandler is wrapped (only once) so that an MFA-pending
// authentication is redirected to mfaUrl. The result composes a "remember username" handler,
// f.successHandler and, if the WebSecurity shares one, the composite global success handler
// (passed first to security.NewAuthenticationSuccessHandler).
func (c *FormLoginConfigurer) effectiveSuccessHandler(f *FormLoginFeature, ws security.WebSecurity) security.AuthenticationSuccessHandler {
	if _, ok := f.successHandler.(*MfaAwareSuccessHandler); f.mfaEnabled && !ok {
		f.successHandler = &MfaAwareSuccessHandler{
			delegate:           f.successHandler,
			mfaPendingDelegate: redirect.NewRedirectWithURL(f.mfaUrl),
		}
	}

	rememberUsernameHandler := newRememberUsernameSuccessHandler(func(h *RememberUsernameSuccessHandler) {
		h.contextPath = c.serverProps.ContextPath
		h.rememberParam = f.rememberParam
		h.cookieDomain = f.rememberCookieDomain
		h.cookieSecured = f.rememberCookieSecured
		// HttpOnly must not be derived from the Secure flag: previously it was set to
		// !rememberCookieSecured, which exposed the cookie to scripts exactly when it was
		// marked Secure. The remembered username is read back server-side from the request cookie.
		// NOTE(review): confirm no client-side script relies on reading this cookie.
		h.cookieHttpOnly = true
		h.cookieMaxAge = f.rememberCookieValidity
	})

	if globalHandler, ok := ws.Shared(security.WSSharedKeyCompositeAuthSuccessHandler).(security.AuthenticationSuccessHandler); ok {
		return security.NewAuthenticationSuccessHandler(globalHandler, rememberUsernameHandler, f.successHandler)
	}
	return security.NewAuthenticationSuccessHandler(rememberUsernameHandler, f.successHandler)
}
// defaultSavedRequestSuccessHandler returns the default login success handler: it replays the
// saved request, falling back to a redirect to the relative path "/", and applies whenever the
// resulting authentication is fully authenticated.
func (c *FormLoginConfigurer) defaultSavedRequestSuccessHandler() security.AuthenticationSuccessHandler {
	fallback := redirect.NewRedirectWithRelativePath("/", true)
	// Note: we changed the condition from security.IsBeingAuthenticated(from, to) to security.IsFullyAuthenticated(to)
	// to handle re-authenticate cases:
	// We allow authenticated user to re-authenticate with same or different user credentials,
	// in such case security.IsBeingAuthenticated would skip the redirect
	condition := func(_, to security.Authentication) bool {
		return security.IsFullyAuthenticated(to)
	}
	return request_cache.NewSavedRequestAuthenticationSuccessHandler(fallback, condition)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package formlogin
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"time"
)
/*********************************
Feature Impl
*********************************/
// FormLoginFeature is the DSL-style security feature for form login with optional MFA (OTP).
// Defaults are set by New; derived defaults are filled in by the configurer's validation.
//
//goland:noinspection GoNameStartsWithPackageName
type FormLoginFeature struct {
	// successHandler handles successful login/MFA; defaulted to a saved-request redirect when nil.
	successHandler security.AuthenticationSuccessHandler
	// failureHandler handles authentication errors; defaulted to a redirect to loginErrorUrl when nil.
	failureHandler security.AuthenticationErrorHandler
	// loginUrl is the login page (GET); required.
	loginUrl string
	// loginProcessUrl receives the login form (POST); defaults to loginUrl.
	loginProcessUrl string
	// loginErrorUrl is the redirect target on login/CSRF errors.
	loginErrorUrl string
	// usernameParam and passwordParam are the POST form field names.
	usernameParam string
	passwordParam string
	// remember* settings configure the remember-username cookie and its form parameter.
	rememberCookieDomain   string
	rememberCookieSecured  bool
	rememberCookieValidity time.Duration
	rememberParam          string
	// mfaEnabled turns on MFA-aware entry point, success and error handling.
	mfaEnabled bool
	// mfaUrl is the MFA page (GET); required when MFA is enabled.
	mfaUrl string
	// mfaVerifyUrl receives the OTP (POST); defaults to mfaUrl when MFA is enabled.
	mfaVerifyUrl string
	// mfaRefreshUrl requests a new OTP (POST).
	mfaRefreshUrl string
	// mfaErrorUrl is the redirect target for errors while MFA is pending.
	mfaErrorUrl string
	// otpParam is the POST form field name carrying the OTP.
	otpParam string
}
// Identifier returns FeatureId.
func (f *FormLoginFeature) Identifier() security.FeatureIdentifier {
	return FeatureId
}

// LoginUrl sets the login page URL (GET).
func (f *FormLoginFeature) LoginUrl(loginUrl string) *FormLoginFeature {
	f.loginUrl = loginUrl
	return f
}

// LoginProcessUrl sets the URL that processes the login form (POST). Defaults to the login URL when empty.
func (f *FormLoginFeature) LoginProcessUrl(loginProcessUrl string) *FormLoginFeature {
	f.loginProcessUrl = loginProcessUrl
	return f
}

// LoginErrorUrl sets the redirect URL used on login and CSRF errors.
func (f *FormLoginFeature) LoginErrorUrl(loginErrorUrl string) *FormLoginFeature {
	f.loginErrorUrl = loginErrorUrl
	return f
}

// UsernameParameter sets the POST form field name of the username.
func (f *FormLoginFeature) UsernameParameter(usernameParam string) *FormLoginFeature {
	f.usernameParam = usernameParam
	return f
}

// PasswordParameter sets the POST form field name of the password.
func (f *FormLoginFeature) PasswordParameter(passwordParam string) *FormLoginFeature {
	f.passwordParam = passwordParam
	return f
}

// RememberParameter sets the form parameter name used by the remember-username handler.
func (f *FormLoginFeature) RememberParameter(rememberParam string) *FormLoginFeature {
	f.rememberParam = rememberParam
	return f
}

// RememberCookieDomain sets the domain of the remember-username cookie.
func (f *FormLoginFeature) RememberCookieDomain(v string) *FormLoginFeature {
	f.rememberCookieDomain = v
	return f
}

// RememberCookieSecured sets whether the remember-username cookie is marked Secure.
func (f *FormLoginFeature) RememberCookieSecured(v bool) *FormLoginFeature {
	f.rememberCookieSecured = v
	return f
}

// RememberCookieValidity sets the max age of the remember-username cookie.
func (f *FormLoginFeature) RememberCookieValidity(v time.Duration) *FormLoginFeature {
	f.rememberCookieValidity = v
	return f
}

// SuccessHandler sets the handler invoked after successful authentication.
// When not set, a saved-request redirect handler is used. If MFA is enabled, it is
// wrapped so that MFA-pending authentications are redirected to the MFA page.
func (f *FormLoginFeature) SuccessHandler(successHandler security.AuthenticationSuccessHandler) *FormLoginFeature {
	f.successHandler = successHandler
	return f
}

// FailureHandler overrides LoginErrorUrl as the authentication error handler
// (CSRF denials still redirect to LoginErrorUrl).
func (f *FormLoginFeature) FailureHandler(failureHandler security.AuthenticationErrorHandler) *FormLoginFeature {
	f.failureHandler = failureHandler
	return f
}

// EnableMFA enables MFA (OTP) handling for form login.
func (f *FormLoginFeature) EnableMFA() *FormLoginFeature {
	f.mfaEnabled = true
	return f
}

// MfaUrl sets the MFA page URL (GET).
func (f *FormLoginFeature) MfaUrl(mfaUrl string) *FormLoginFeature {
	f.mfaUrl = mfaUrl
	return f
}

// MfaVerifyUrl sets the OTP verification URL (POST).
func (f *FormLoginFeature) MfaVerifyUrl(mfaVerifyUrl string) *FormLoginFeature {
	f.mfaVerifyUrl = mfaVerifyUrl
	return f
}

// MfaRefreshUrl sets the OTP refresh URL (POST).
func (f *FormLoginFeature) MfaRefreshUrl(mfaRefreshUrl string) *FormLoginFeature {
	f.mfaRefreshUrl = mfaRefreshUrl
	return f
}

// MfaErrorUrl sets the redirect URL used on errors while MFA is pending.
func (f *FormLoginFeature) MfaErrorUrl(mfaErrorUrl string) *FormLoginFeature {
	f.mfaErrorUrl = mfaErrorUrl
	return f
}

// OtpParameter sets the POST form field name of the OTP.
func (f *FormLoginFeature) OtpParameter(otpParam string) *FormLoginFeature {
	f.otpParam = otpParam
	return f
}
/*********************************
Constructors and Configure
*********************************/
// Configure enables form login on ws with default settings and returns the feature for
// further DSL-style customization. It panics if ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *FormLoginFeature {
	feature := New()
	modifier, supported := ws.(security.FeatureModifier)
	if !supported {
		panic(fmt.Errorf("unable to configure form login: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
	}
	return modifier.Enable(feature).(*FormLoginFeature)
}
// New is the standard security.Feature entrypoint, DSL style. Used with security.WebSecurity.
// It returns a feature pre-populated with default URLs, parameter names and a one-hour
// remember-username cookie validity; MFA is disabled by default.
func New() *FormLoginFeature {
	return &FormLoginFeature{
		// login
		loginUrl:        "/login",
		loginProcessUrl: "/login",
		loginErrorUrl:   "/login?error=true",
		usernameParam:   "username",
		passwordParam:   "password",
		// remember username
		rememberParam:          "remember-me",
		rememberCookieValidity: time.Hour,
		// MFA
		mfaUrl:        "/login/mfa",
		mfaVerifyUrl:  "/login/mfa",
		mfaRefreshUrl: "/login/mfa/refresh",
		mfaErrorUrl:   "/login/mfa?error=true",
		otpParam:      "otp",
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package formlogin
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
"net/http"
)
// MfaAwareAuthenticationEntryPoint dispatches to mfaPendingDelegate when the current
// authentication is a username/password authentication with MFA pending, otherwise to delegate.
type MfaAwareAuthenticationEntryPoint struct {
	delegate           security.AuthenticationEntryPoint
	mfaPendingDelegate security.AuthenticationEntryPoint
}

// Commence implements security.AuthenticationEntryPoint.
func (h *MfaAwareAuthenticationEntryPoint) Commence(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	target := h.delegate
	if current, isUserAuth := security.Get(c).(passwd.UsernamePasswordAuthentication); isUserAuth && current.IsMFAPending() {
		target = h.mfaPendingDelegate
	}
	target.Commence(c, r, rw, err)
}
// MfaAwareSuccessHandler dispatches to mfaPendingDelegate when the new authentication is a
// username/password authentication with MFA pending, otherwise to delegate.
type MfaAwareSuccessHandler struct {
	delegate           security.AuthenticationSuccessHandler
	mfaPendingDelegate security.AuthenticationSuccessHandler
}

// HandleAuthenticationSuccess implements security.AuthenticationSuccessHandler.
func (h *MfaAwareSuccessHandler) HandleAuthenticationSuccess(
	c context.Context, r *http.Request, rw http.ResponseWriter, from, to security.Authentication) {
	target := h.delegate
	if newAuth, isUserAuth := to.(passwd.UsernamePasswordAuthentication); isUserAuth && newAuth.IsMFAPending() {
		target = h.mfaPendingDelegate
	}
	target.HandleAuthenticationSuccess(c, r, rw, from, to)
}
// MfaAwareAuthenticationErrorHandler dispatches to mfaPendingDelegate when the current
// authentication is a username/password authentication with MFA pending, otherwise to delegate.
type MfaAwareAuthenticationErrorHandler struct {
	delegate           security.AuthenticationErrorHandler
	mfaPendingDelegate security.AuthenticationErrorHandler
}

// HandleAuthenticationError implements security.AuthenticationErrorHandler.
func (h *MfaAwareAuthenticationErrorHandler) HandleAuthenticationError(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	target := h.delegate
	if current, isUserAuth := security.Get(c).(passwd.UsernamePasswordAuthentication); isUserAuth && current.IsMFAPending() {
		target = h.mfaPendingDelegate
	}
	target.HandleAuthenticationError(c, r, rw, err)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package formlogin
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/gin-gonic/gin"
)
// FormAuthenticationMiddleware authenticates username/password submitted via a POST form.
type FormAuthenticationMiddleware struct {
	authenticator  security.Authenticator
	successHandler security.AuthenticationSuccessHandler
	usernameParam  string
	passwordParam  string
}

// FormAuthMWOptionsFunc customizes FormAuthMWOptions.
type FormAuthMWOptionsFunc func(*FormAuthMWOptions)

// FormAuthMWOptions holds the settings of a FormAuthenticationMiddleware.
type FormAuthMWOptions struct {
	Authenticator  security.Authenticator
	SuccessHandler security.AuthenticationSuccessHandler
	UsernameParam  string
	PasswordParam  string
}

// NewFormAuthenticationMiddleware creates the middleware by applying the option functions in
// order to zero-valued options; nil option functions are skipped.
func NewFormAuthenticationMiddleware(optionFuncs ...FormAuthMWOptionsFunc) *FormAuthenticationMiddleware {
	var opts FormAuthMWOptions
	for _, apply := range optionFuncs {
		if apply == nil {
			continue
		}
		apply(&opts)
	}
	return &FormAuthenticationMiddleware{
		authenticator:  opts.Authenticator,
		successHandler: opts.SuccessHandler,
		usernameParam:  opts.UsernameParam,
		passwordParam:  opts.PasswordParam,
	}
}
// LoginProcessHandlerFunc returns a gin handler that reads the first username and password
// values from the POST form (empty string when absent), authenticates them with MFA mode
// "optional", and delegates to handleSuccess or handleError.
func (mw *FormAuthenticationMiddleware) LoginProcessHandlerFunc() gin.HandlerFunc {
	// firstPostForm returns the first value of the POST form field key, or "" if absent.
	firstPostForm := func(ctx *gin.Context, key string) string {
		if values := ctx.PostFormArray(key); len(values) != 0 {
			return values[0]
		}
		return ""
	}

	return func(ctx *gin.Context) {
		username := firstPostForm(ctx, mw.usernameParam)
		password := firstPostForm(ctx, mw.passwordParam)

		before := security.Get(ctx)
		current, isUserAuth := before.(passwd.UsernamePasswordAuthentication)
		//nolint:staticcheck // empty block for document purpose
		if isUserAuth && passwd.IsSamePrincipal(username, current) {
			// We currently allow re-authenticate without logout.
			// If we don't want to allow it, we need to figure out how to error out without clearing the authentication.
			// Note: currently, clearing authentication happens in error handling middleware on all SecurityAuthenticationError
		}

		candidate := passwd.UsernamePasswordPair{
			Username:   username,
			Password:   password,
			EnforceMFA: passwd.MFAModeOptional,
		}
		auth, err := mw.authenticator.Authenticate(ctx, &candidate)
		if err != nil {
			mw.handleError(ctx, err, &candidate)
			return
		}
		mw.handleSuccess(ctx, before, auth)
	}
}
// handleSuccess stores the new authentication (if any), invokes the success handler, and
// aborts the gin chain if the handler already wrote a response.
func (mw *FormAuthenticationMiddleware) handleSuccess(c *gin.Context, before, after security.Authentication) {
	if after != nil {
		security.MustSet(c, after)
	}
	mw.successHandler.HandleAuthenticationSuccess(c, c.Request, c.Writer, before, after)
	if !c.Writer.Written() {
		return
	}
	c.Abort()
}

// handleError clears the current authentication, flashes the attempted principal into the
// session under the username parameter name (so the form can be pre-filled), records err on
// the gin context and aborts the chain.
func (mw *FormAuthenticationMiddleware) handleError(c *gin.Context, err error, candidate security.Candidate) {
	security.MustClear(c)
	if candidate != nil {
		if s := session.Get(c); s != nil {
			s.AddFlash(candidate.Principal(), mw.usernameParam)
		}
	}
	// c.Error returns the wrapping *gin.Error, which is not needed here
	_ = c.Error(err)
	c.Abort()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package formlogin
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
errorutils "github.com/cisco-open/go-lanai/pkg/utils/error"
"github.com/gin-gonic/gin"
)
// MfaAuthenticationMiddleware handles OTP verification and OTP refresh for an authentication
// whose MFA is pending. (The former empty `var ()` block here was dead code and is removed.)
type MfaAuthenticationMiddleware struct {
	authenticator  security.Authenticator
	successHandler security.AuthenticationSuccessHandler
	otpParam       string
}
// MfaMWOptionsFunc customizes MfaMWOptions.
type MfaMWOptionsFunc func(*MfaMWOptions)

// MfaMWOptions holds the settings of an MfaAuthenticationMiddleware.
type MfaMWOptions struct {
	Authenticator  security.Authenticator
	SuccessHandler security.AuthenticationSuccessHandler
	OtpParam       string
}

// NewMfaAuthenticationMiddleware creates the middleware by applying the option functions in
// order to zero-valued options; nil option functions are skipped.
func NewMfaAuthenticationMiddleware(optionFuncs ...MfaMWOptionsFunc) *MfaAuthenticationMiddleware {
	var opts MfaMWOptions
	for _, apply := range optionFuncs {
		if apply == nil {
			continue
		}
		apply(&opts)
	}
	mw := &MfaAuthenticationMiddleware{}
	mw.authenticator = opts.Authenticator
	mw.successHandler = opts.SuccessHandler
	mw.otpParam = opts.OtpParam
	return mw
}
// OtpVerifyHandlerFunc returns a gin handler that verifies the first OTP value from the POST
// form (empty string when absent) against the current MFA-pending authentication.
func (mw *MfaAuthenticationMiddleware) OtpVerifyHandlerFunc() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		var otp string
		if values := ctx.PostFormArray(mw.otpParam); len(values) > 0 {
			otp = values[0]
		}

		before, err := mw.currentAuth(ctx)
		if err != nil {
			mw.handleError(ctx, err, nil)
			return
		}

		candidate := passwd.MFAOtpVerification{
			CurrentAuth: before,
			OTP:         otp,
			DetailsMap:  map[string]interface{}{},
		}
		auth, err := mw.authenticator.Authenticate(ctx, &candidate)
		if err != nil {
			mw.handleError(ctx, err, &candidate)
			return
		}
		mw.handleSuccess(ctx, before, auth)
	}
}
// OtpRefreshHandlerFunc returns a gin handler that requests an OTP refresh for the current
// MFA-pending authentication.
func (mw *MfaAuthenticationMiddleware) OtpRefreshHandlerFunc() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		before, err := mw.currentAuth(ctx)
		if err != nil {
			mw.handleError(ctx, err, nil)
			return
		}

		candidate := passwd.MFAOtpRefresh{
			CurrentAuth: before,
			DetailsMap:  map[string]interface{}{},
		}
		auth, authErr := mw.authenticator.Authenticate(ctx, &candidate)
		if authErr != nil {
			mw.handleError(ctx, authErr, &candidate)
			return
		}
		mw.handleSuccess(ctx, before, auth)
	}
}
// currentAuth returns the current authentication if it is a username/password authentication
// with MFA pending; otherwise it returns an access denied error.
func (mw *MfaAuthenticationMiddleware) currentAuth(ctx *gin.Context) (passwd.UsernamePasswordAuthentication, error) {
	currentAuth, ok := security.Get(ctx).(passwd.UsernamePasswordAuthentication)
	if !ok || !currentAuth.IsMFAPending() {
		// message typo fixed ("progess" -> "progress")
		return nil, security.NewAccessDeniedError("MFA is not in progress")
	}
	return currentAuth, nil
}
// handleSuccess stores the new authentication (if any), invokes the success handler, and
// aborts the gin chain if the handler already wrote a response.
func (mw *MfaAuthenticationMiddleware) handleSuccess(c *gin.Context, before, after security.Authentication) {
	if after != nil {
		security.MustSet(c, after)
	}
	mw.successHandler.HandleAuthenticationSuccess(c, c.Request, c.Writer, before, after)
	if !c.Writer.Written() {
		return
	}
	c.Abort()
}

// handleError clears the current authentication only for errors selected by shouldClear
// (so the MFA-pending state survives e.g. a wrong OTP), records err on the gin context and
// aborts the chain. candidate is currently unused.
func (mw *MfaAuthenticationMiddleware) handleError(c *gin.Context, err error, candidate security.Candidate) {
	if clear := mw.shouldClear(err); clear {
		security.MustClear(c)
	}
	// c.Error returns the wrapping *gin.Error, which is not needed here
	_ = c.Error(err)
	c.Abort()
}
// shouldClear reports whether the current authentication should be cleared after err.
// Only an error that itself implements errorutils.ErrorCoder (no unwrapping) with code
// ErrorCodeCredentialsExpired or ErrorCodeMaxAttemptsReached qualifies.
//
// Previously this used `switch coder, ok := ...; ok` whose case expressions call coder.Code()
// even when the assertion failed, panicking on a nil interface for non-coded errors.
func (mw *MfaAuthenticationMiddleware) shouldClear(err error) bool {
	//nolint:errorlint // intentionally inspects only the top-level error
	coder, ok := err.(errorutils.ErrorCoder)
	if !ok {
		return false
	}
	code := coder.Code()
	return code == security.ErrorCodeCredentialsExpired || code == security.ErrorCodeMaxAttemptsReached
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package formlogin
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/web"
"go.uber.org/fx"
)
// logger is the package logger, named "SEC.Login".
var logger = log.New("SEC.Login")

// Module is the bootstrap module of form login; it invokes register to hook the form login
// configurer into the security registrar. Precedence is security.MinSecurityPrecedence + 20.
//
//goland:noinspection GoNameStartsWithPackageName
var Module = &bootstrap.Module{
	Name:       "form login",
	Precedence: security.MinSecurityPrecedence + 20,
	Options: []fx.Option{
		fx.Invoke(register),
	},
}

// init registers Module with bootstrap.
func init() {
	bootstrap.Register(Module)
}
// initDI holds the dependencies injected into register. The security registrar is optional.
type initDI struct {
	fx.In
	SecRegistrar security.Registrar `optional:"true"`
	ServerProps  web.ServerProperties
}

// register registers the form login configurer under FeatureId when a security registrar is
// available. It panics if the registrar does not implement security.FeatureRegistrar.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	configurer := newFormLoginConfigurer(di.ServerProps)
	featureRegistrar := di.SecRegistrar.(security.FeatureRegistrar)
	featureRegistrar.RegisterFeature(FeatureId, configurer)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package formlogin
import (
"context"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/template"
"strings"
)
// Template model keys populated by DefaultFormLoginController.
const (
	// login page
	LoginModelKeyUsernameParam      = "usernameParam"
	LoginModelKeyPasswordParam      = "passwordParam"
	LoginModelKeyLoginProcessUrl    = "loginProcessUrl"
	LoginModelKeyRememberedUsername = "rememberedUsername"
	// MFA page
	LoginModelKeyOtpParam      = "otpParam"
	LoginModelKeyMfaVerifyUrl  = "mfaVerifyUrl"
	LoginModelKeyMfaRefreshUrl = "mfaRefreshUrl"
	// both pages
	LoginModelKeyMsxVersion = "MSXVersion"
)
// DefaultFormLoginController renders the default login page and MFA (OTP) page templates.
type DefaultFormLoginController struct {
	// buildInfoResolver, when set, supplies the version shown on the pages.
	buildInfoResolver bootstrap.BuildInfoResolver
	loginTemplate     string
	loginProcessUrl   string
	usernameParam     string
	passwordParam     string
	mfaTemplate       string
	mfaVerifyUrl      string
	mfaRefreshUrl     string
	otpParam          string
}

// PageOptionsFunc customizes DefaultFormLoginPageOptions.
type PageOptionsFunc func(*DefaultFormLoginPageOptions)

// DefaultFormLoginPageOptions holds the settings of a DefaultFormLoginController.
type DefaultFormLoginPageOptions struct {
	BuildInfoResolver bootstrap.BuildInfoResolver
	LoginTemplate     string
	UsernameParam     string
	PasswordParam     string
	LoginProcessUrl   string
	MfaTemplate       string
	OtpParam          string
	MfaVerifyUrl      string
	MfaRefreshUrl     string
}
// NewDefaultLoginFormController creates a DefaultFormLoginController by applying the option
// functions in order to zero-valued options. nil option functions are skipped (previously they
// caused a nil function call panic), consistent with the middleware constructors of this package.
func NewDefaultLoginFormController(options ...PageOptionsFunc) *DefaultFormLoginController {
	opts := DefaultFormLoginPageOptions{}
	for _, f := range options {
		if f != nil {
			f(&opts)
		}
	}
	return &DefaultFormLoginController{
		buildInfoResolver: opts.BuildInfoResolver,
		loginTemplate:     opts.LoginTemplate,
		loginProcessUrl:   opts.LoginProcessUrl,
		usernameParam:     opts.UsernameParam,
		passwordParam:     opts.PasswordParam,
		mfaTemplate:       opts.MfaTemplate,
		mfaVerifyUrl:      opts.MfaVerifyUrl,
		mfaRefreshUrl:     opts.MfaRefreshUrl,
		otpParam:          opts.OtpParam,
	}
}
// LoginRequest is bound from the login page query/form; Error is true when the page is
// shown after a failed attempt (e.g. "?error=true").
type LoginRequest struct {
	Error bool `form:"error"`
}

// OTPVerificationRequest is bound from the MFA page query/form; Error is true when the page
// is shown after a failed attempt.
type OTPVerificationRequest struct {
	Error bool `form:"error"`
}
// Mappings returns the template mappings for GET /login (login page) and GET /login/mfa (MFA page).
func (c *DefaultFormLoginController) Mappings() []web.Mapping {
	loginPage := template.New().Get("/login").HandlerFunc(c.LoginForm).Build()
	mfaPage := template.New().Get("/login/mfa").HandlerFunc(c.OtpVerificationForm).Build()
	return []web.Mapping{loginPage, mfaPage}
}
// LoginForm renders the login page. The model carries parameter names, the processing URL and
// the version. If a session exists, the flashed previous error is shown when r.Error is set and
// the flashed username is pre-filled. The remembered-username cookie is added when present.
func (c *DefaultFormLoginController) LoginForm(ctx context.Context, r *LoginRequest) (*template.ModelView, error) {
	model := template.Model{
		LoginModelKeyUsernameParam:   c.usernameParam,
		LoginModelKeyPasswordParam:   c.passwordParam,
		LoginModelKeyLoginProcessUrl: c.loginProcessUrl,
		LoginModelKeyMsxVersion:      c.msxVersion(),
	}

	if s := session.Get(ctx); s != nil {
		// the flash is read regardless of r.Error; it is only shown when r.Error is set
		prevErr, isErr := s.Flash(redirect.FlashKeyPreviousError).(error)
		if isErr && r.Error {
			model[template.ModelKeyError] = prevErr
		}
		if username, isStr := s.Flash(c.usernameParam).(string); isStr {
			model[c.usernameParam] = username
		}
	}

	if gc := web.GinContext(ctx); gc != nil {
		remembered, cookieErr := gc.Cookie(CookieKeyRememberedUsername)
		if cookieErr == nil && remembered != "" {
			model[LoginModelKeyRememberedUsername] = remembered
		}
	}

	view := &template.ModelView{
		View:  c.loginTemplate,
		Model: model,
	}
	return view, nil
}
// OtpVerificationForm renders the MFA (OTP) page.
//
// The model carries the OTP parameter name, the verify and refresh URLs and the version
// string. When a session exists, a flashed previous error is exposed only if r.Error is set.
func (c *DefaultFormLoginController) OtpVerificationForm(ctx context.Context, r *OTPVerificationRequest) (*template.ModelView, error) {
	model := template.Model{}
	model[LoginModelKeyOtpParam] = c.otpParam
	model[LoginModelKeyMfaVerifyUrl] = c.mfaVerifyUrl
	model[LoginModelKeyMfaRefreshUrl] = c.mfaRefreshUrl
	model[LoginModelKeyMsxVersion] = c.msxVersion()

	if s := session.Get(ctx); s != nil {
		// The flash is always read, regardless of r.Error.
		prevErr, isErr := s.Flash(redirect.FlashKeyPreviousError).(error)
		if isErr && r.Error {
			model[template.ModelKeyError] = prevErr
		}
	}

	view := &template.ModelView{View: c.mfaTemplate, Model: model}
	return view, nil
}
// msxVersion returns the version string shown on the login pages.
// A configured buildInfoResolver takes precedence. Otherwise bootstrap.BuildVersion is used,
// except that the placeholder "unknown" (in any letter case) is reported as "".
func (c *DefaultFormLoginController) msxVersion() string {
	switch {
	case c.buildInfoResolver != nil:
		return c.buildInfoResolver.Resolve().Version
	case strings.ToLower(bootstrap.BuildVersion) == "unknown":
		return ""
	default:
		return bootstrap.BuildVersion
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package formlogin
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"math"
"net/http"
"net/url"
"time"
)
const (
	// detailsKeyShouldRememberUsername is the authentication-details key that stores the user's remember-username decision.
	detailsKeyShouldRememberUsername = "RememberUsername"
)

// rememberMeOptions customizes a RememberUsernameSuccessHandler during construction.
type rememberMeOptions func(h *RememberUsernameSuccessHandler)

// RememberUsernameSuccessHandler sets or clears the remembered-username cookie after authentication succeeds.
type RememberUsernameSuccessHandler struct {
	contextPath    string
	rememberParam  string
	cookieDomain   string
	cookieSecured  bool
	cookieHttpOnly bool
	cookieMaxAge   time.Duration
}

// newRememberUsernameSuccessHandler creates a handler whose cookie path defaults to "/",
// then applies every option in order.
func newRememberUsernameSuccessHandler(opts ...rememberMeOptions) *RememberUsernameSuccessHandler {
	handler := &RememberUsernameSuccessHandler{contextPath: "/"}
	for _, apply := range opts {
		apply(handler)
	}
	return handler
}
// HandleAuthenticationSuccess records and applies the user's "remember username" choice.
//
// If the submitted form has a non-empty h.rememberParam value, the decision is stored in the
// authentication details map under detailsKeyShouldRememberUsername. Until the authentication
// reaches security.StateAuthenticated, nothing else happens. After that, the remembered-username
// cookie is written when the stored flag is true, for a security.Account or string principal.
// Otherwise the cookie is cleared.
//
// NOTE(review): r.PostForm is only populated if the form was parsed earlier in the chain. When
// to.Details() is not a map[string]interface{}, the fresh map built here is not attached to the
// authentication, so the decision does not outlive this call — confirm this is intended.
func (h *RememberUsernameSuccessHandler) HandleAuthenticationSuccess(c context.Context, r *http.Request, rw http.ResponseWriter, _, to security.Authentication) {
	details, isMap := to.Details().(map[string]interface{})
	if !isMap {
		details = map[string]interface{}{}
	}

	if r.PostForm.Get(h.rememberParam) != "" {
		details[detailsKeyShouldRememberUsername] = true
	}

	// Authentication is not complete yet (e.g. more factors pending); decide later.
	if to.State() < security.StateAuthenticated {
		return
	}

	// A missing or non-bool entry reads as false, meaning "forget the username".
	shouldRemember, _ := details[detailsKeyShouldRememberUsername].(bool)
	if !shouldRemember {
		h.clear(c, rw)
		return
	}

	switch principal := to.Principal().(type) {
	case security.Account:
		h.save(principal.Username(), c, rw)
	case string:
		h.save(principal, c, rw)
	}
}
// save sets the remembered-username cookie to username, with lifetime h.cookieMaxAge
// (0 yields a session cookie).
func (h *RememberUsernameSuccessHandler) save(username string, _ context.Context, rw http.ResponseWriter) {
	cookie := h.newCookie(CookieKeyRememberedUsername, username, h.cookieMaxAge)
	http.SetCookie(rw, cookie)
}

// clear tells the client to delete the remembered-username cookie.
func (h *RememberUsernameSuccessHandler) clear(_ context.Context, rw http.ResponseWriter) {
	// Pass a whole negative second. newCookie rounds maxAge to seconds, so a literal -1
	// (a -1ns time.Duration) would round to MaxAge 0, which http.Cookie treats as
	// "no Max-Age". A negative MaxAge is what means "delete now".
	cookie := h.newCookie(CookieKeyRememberedUsername, "", -time.Second)
	http.SetCookie(rw, cookie)
}
// newCookie builds the remembered-username cookie called name, with a URL-query-escaped value.
//
// maxAge follows http.Cookie semantics:
//   - 0: no Max-Age attribute and no Expires (a session cookie);
//   - positive: Max-Age rounded to whole seconds, and Expires = now + maxAge;
//   - negative: delete the cookie now (MaxAge -1, serialized as "Max-Age=0").
func (h *RememberUsernameSuccessHandler) newCookie(name, value string, maxAge time.Duration) *http.Cookie {
	maxAgeSeconds := int(math.Round(maxAge.Seconds()))
	if maxAge < 0 {
		// Any negative duration means "delete". Without this, sub-second negatives such as
		// -1ns round to 0, which http.Cookie interprets as "no Max-Age attribute".
		maxAgeSeconds = -1
	}
	return &http.Cookie{
		Name:     name,
		Value:    url.QueryEscape(value),
		Path:     h.contextPath,
		Domain:   h.cookieDomain,
		MaxAge:   maxAgeSeconds,
		Expires:  calculateCookieExpires(maxAge),
		Secure:   h.cookieSecured,
		HttpOnly: h.cookieHttpOnly,
		// The remember-me cookie should not be sent cross-site (unlike the session cookie,
		// which needs to work cross-site for SSO).
		SameSite: http.SameSiteStrictMode,
	}
}
func calculateCookieExpires(maxAge time.Duration) time.Time {
if maxAge == 0 {
return time.Time{}
}
return time.Now().Add(maxAge)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package idp
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
util_matcher "github.com/cisco-open/go-lanai/pkg/utils/matcher"
netutil "github.com/cisco-open/go-lanai/pkg/utils/net"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"net/http"
)
// logger is this package's logger (name "SEC.IDP"); RequestWithAuthenticationFlow uses it to report unresolved domains.
var logger = log.New("SEC.IDP")
const (
	// InternalIdpForm is the flow of identity providers that authenticate with the internal login form.
	InternalIdpForm = AuthenticationFlow("InternalIdpForm")
	// ExternalIdpSAML is the flow of identity providers that delegate authentication to an external SAML IdP.
	ExternalIdpSAML = AuthenticationFlow("ExternalIdpSAML")
	// UnknownIdp is the flow used when no identity provider, or no flow, can be determined for a request.
	UnknownIdp = AuthenticationFlow("UnKnown")
)

// AuthenticationFlow names the way an identity provider authenticates users.
type AuthenticationFlow string

// MarshalText implements encoding.TextMarshaler. The flow is encoded as its raw string value.
func (f AuthenticationFlow) MarshalText() ([]byte, error) {
	return []byte(f), nil
}

// UnmarshalText implements encoding.TextUnmarshaler.
//
// It accepts exactly the text MarshalText produces for each declared flow, including UnknownIdp,
// so every constant round-trips. Matching is case-sensitive. Any other value returns an error
// and leaves *f unchanged.
func (f *AuthenticationFlow) UnmarshalText(data []byte) error {
	switch flow := AuthenticationFlow(data); flow {
	case InternalIdpForm, ExternalIdpSAML, UnknownIdp:
		*f = flow
		return nil
	default:
		return fmt.Errorf("unrecognized authentication flow: %s", string(data))
	}
}
// IdentityProvider is the minimal contract of a configured identity provider.
type IdentityProvider interface {
// Domain returns the domain the provider is registered for.
Domain() string
}
// AuthenticationFlowAware is optionally implemented by an IdentityProvider to report its AuthenticationFlow.
// RequestWithAuthenticationFlow treats providers that do not implement it as UnknownIdp.
type AuthenticationFlowAware interface {
AuthenticationFlow() AuthenticationFlow
}
// IdentityProviderManager looks up configured identity providers.
type IdentityProviderManager interface {
// GetIdentityProvidersWithFlow returns the providers that use the given flow.
GetIdentityProvidersWithFlow(ctx context.Context, flow AuthenticationFlow) []IdentityProvider
// GetIdentityProviderByDomain returns the provider for domain. RequestWithAuthenticationFlow treats a non-nil error as "no IdP for this domain".
GetIdentityProviderByDomain(ctx context.Context, domain string) (IdentityProvider, error)
}
func RequestWithAuthenticationFlow(flow AuthenticationFlow, idpManager IdentityProviderManager) web.RequestMatcher {
matchableError := func() (interface{}, error) {
return string(UnknownIdp), nil
}
matchable := func(ctx context.Context, request *http.Request) (interface{}, error) {
var host = netutil.GetForwardedHostName(request)
idp, err := idpManager.GetIdentityProviderByDomain(ctx, host)
if err != nil {
logger.WithContext(ctx).Debugf("cannot find idp for domain %s", host)
return matchableError()
}
fa, ok := idp.(AuthenticationFlowAware)
if !ok {
return matchableError()
}
return string(fa.AuthenticationFlow()), nil
}
return matcher.CustomMatcher(fmt.Sprintf("IDP with [%s]", flow),
matchable,
util_matcher.WithString(string(flow), true))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package extsamlidp
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/config/authserver"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/idp"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
samlsp "github.com/cisco-open/go-lanai/pkg/security/saml/sp"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
)
// Options configures NewSamlIdpSecurityConfigurer.
type Options func(opt *option)

// option collects the settings applied by Options.
type option struct {
	Properties *SamlAuthProperties
}

// WithProperties replaces the SamlAuthProperties the configurer uses
// (the default is NewSamlAuthProperties()).
func WithProperties(props *SamlAuthProperties) Options {
	apply := func(o *option) {
		o.Properties = props
	}
	return apply
}
// SamlIdpSecurityConfigurer implements authserver.IdpSecurityConfigurer
//goland:noinspection GoNameStartsWithPackageName
type SamlIdpSecurityConfigurer struct {
	props *SamlAuthProperties
}

// NewSamlIdpSecurityConfigurer creates a SamlIdpSecurityConfigurer. Properties default to
// NewSamlAuthProperties(), whose zero value leaves SAML login disabled, unless overridden with WithProperties.
func NewSamlIdpSecurityConfigurer(opts ...Options) *SamlIdpSecurityConfigurer {
	settings := &option{Properties: NewSamlAuthProperties()}
	for _, apply := range opts {
		apply(settings)
	}
	return &SamlIdpSecurityConfigurer{props: settings.Properties}
}
// Configure sets up security for authorize requests served by an external SAML IdP.
//
// ws is first restricted, via AndCondition, to requests whose host resolves to an IdP with the
// ExternalIdpSAML flow. This happens even when SAML login is disabled, in which case nothing else
// is configured. When enabled, it adds SAML SP login (issuer and error path from config), session
// handling with config.SessionSettingService, an "authenticated for any request" access rule, and
// an access-denied handler that redirects to config.Endpoints.Error.
func (c *SamlIdpSecurityConfigurer) Configure(ws security.WebSecurity, config *authserver.Configuration) {
// For Authorize endpoint
condition := idp.RequestWithAuthenticationFlow(idp.ExternalIdpSAML, config.IdpManager)
ws = ws.AndCondition(condition)
if !c.props.Enabled {
return
}
// access-denied errors are redirected to the auth server's error endpoint
handler := redirect.NewRedirectWithURL(config.Endpoints.Error)
ws.
With(samlsp.New().
Issuer(config.Issuer).
ErrorPath(config.Endpoints.Error),
).
With(session.New().SettingService(config.SessionSettingService)).
With(access.New().
Request(matcher.AnyRequest()).Authenticated(),
).
With(errorhandling.New().
AccessDeniedHandler(handler),
)
}
// ConfigureLogout adds SAML SP logout support to ws when SAML login is enabled.
// When it is disabled, ws is left untouched.
func (c *SamlIdpSecurityConfigurer) ConfigureLogout(ws security.WebSecurity, config *authserver.Configuration) {
	if c.props.Enabled {
		logout := samlsp.NewLogout().
			Issuer(config.Issuer).
			ErrorPath(config.Endpoints.Error)
		ws.With(logout)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package extsamlidp
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/idp"
)
// SamlIdpAutoCreateUserDetails holds the settings for automatically creating local users from
// an external SAML IdP. SamlIdentityProvider.GetAutoCreateUserDetails returns it as a
// security.AutoCreateUserDetails.
type SamlIdpAutoCreateUserDetails struct {
	Enabled               bool
	EmailWhiteList        []string
	AttributeMapping      map[string]string
	ElevatedUserRoleNames []string
	RegularUserRoleNames  []string
}

// GetElevatedUserRoleNames returns ElevatedUserRoleNames (the slice itself, not a copy).
func (d SamlIdpAutoCreateUserDetails) GetElevatedUserRoleNames() []string {
	return d.ElevatedUserRoleNames
}

// GetRegularUserRoleNames returns RegularUserRoleNames (the slice itself, not a copy).
func (d SamlIdpAutoCreateUserDetails) GetRegularUserRoleNames() []string {
	return d.RegularUserRoleNames
}

// IsEnabled reports whether auto-creation is enabled.
func (d SamlIdpAutoCreateUserDetails) IsEnabled() bool {
	return d.Enabled
}

// GetEmailWhiteList returns EmailWhiteList (the slice itself, not a copy).
func (d SamlIdpAutoCreateUserDetails) GetEmailWhiteList() []string {
	return d.EmailWhiteList
}

// GetAttributeMapping returns AttributeMapping (the map itself, not a copy).
func (d SamlIdpAutoCreateUserDetails) GetAttributeMapping() map[string]string {
	return d.AttributeMapping
}
// SamlIdpDetails holds the configuration of one external SAML identity provider.
type SamlIdpDetails struct {
	EntityId                 string
	Domain                   string
	MetadataLocation         string
	ExternalIdName           string
	ExternalIdpName          string
	MetadataRequireSignature bool
	MetadataTrustCheck       bool
	MetadataTrustedKeys      []string
	AutoCreateUserDetails    SamlIdpAutoCreateUserDetails
}

// SamlIdpOptions mutates SamlIdpDetails during NewIdentityProvider.
type SamlIdpOptions func(opt *SamlIdpDetails)

// SamlIdentityProvider implements idp.IdentityProvider, idp.AuthenticationFlowAware and samllogin.SamlIdentityProvider
//
// Several accessor methods share names with fields of the embedded SamlIdpDetails. The methods
// shadow the promoted fields, which is why those accessors spell out p.SamlIdpDetails.<Field>.
type SamlIdentityProvider struct {
	SamlIdpDetails
}

// NewIdentityProvider applies every option, in order, to a zero-valued SamlIdpDetails and
// returns a provider holding a copy of the result.
func NewIdentityProvider(opts ...SamlIdpOptions) *SamlIdentityProvider {
	var details SamlIdpDetails
	for _, apply := range opts {
		apply(&details)
	}
	return &SamlIdentityProvider{SamlIdpDetails: details}
}

// AuthenticationFlow always reports idp.ExternalIdpSAML.
func (p SamlIdentityProvider) AuthenticationFlow() idp.AuthenticationFlow {
	return idp.ExternalIdpSAML
}

// Domain returns the configured domain.
func (p SamlIdentityProvider) Domain() string {
	return p.SamlIdpDetails.Domain
}

// EntityId returns the configured SAML entity ID.
func (p SamlIdentityProvider) EntityId() string {
	return p.SamlIdpDetails.EntityId
}

// MetadataLocation returns the configured metadata location.
func (p SamlIdentityProvider) MetadataLocation() string {
	return p.SamlIdpDetails.MetadataLocation
}

// ExternalIdName returns the configured ExternalIdName.
func (p SamlIdentityProvider) ExternalIdName() string {
	return p.SamlIdpDetails.ExternalIdName
}

// ExternalIdpName returns the configured ExternalIdpName.
func (p SamlIdentityProvider) ExternalIdpName() string {
	return p.SamlIdpDetails.ExternalIdpName
}

// ShouldMetadataRequireSignature returns MetadataRequireSignature.
func (p SamlIdentityProvider) ShouldMetadataRequireSignature() bool {
	return p.MetadataRequireSignature
}

// ShouldMetadataTrustCheck returns MetadataTrustCheck.
func (p SamlIdentityProvider) ShouldMetadataTrustCheck() bool {
	return p.MetadataTrustCheck
}

// GetMetadataTrustedKeys returns MetadataTrustedKeys (the slice itself, not a copy).
func (p SamlIdentityProvider) GetMetadataTrustedKeys() []string {
	return p.MetadataTrustedKeys
}

// GetAutoCreateUserDetails returns the auto-create settings as a security.AutoCreateUserDetails.
func (p SamlIdentityProvider) GetAutoCreateUserDetails() security.AutoCreateUserDetails {
	return p.SamlIdpDetails.AutoCreateUserDetails
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package extsamlidp
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
samlsp "github.com/cisco-open/go-lanai/pkg/security/saml/sp"
"go.uber.org/fx"
)
//var logger = log.New("SEC.SAML")
// Module provides SamlAuthProperties, bound from configuration under PropertiesPrefix, for the external SAML IdP support.
var Module = &bootstrap.Module{
Name: "SAML IDP",
Precedence: security.MaxSecurityPrecedence - 100,
Options: []fx.Option {
fx.Provide(BindSamlAuthProperties),
},
}
// Use enables SAML SP support via samlsp.Use and registers Module with bootstrap.
func Use() {
samlsp.Use() // samlsp enables external SAML IdP support
bootstrap.Register(Module)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package extsamlidp
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/pkg/errors"
)
const (
	// PropertiesPrefix is the configuration prefix bound into SamlAuthProperties.
	PropertiesPrefix = "security.idp.saml"
)

// SamlAuthProperties configures the external SAML IdP support.
type SamlAuthProperties struct {
	Enabled   bool                       `json:"enabled"`
	Endpoints SamlAuthEndpointProperties `json:"endpoints"`
}

// SamlAuthEndpointProperties currently has no fields.
type SamlAuthEndpointProperties struct{}

// NewSamlAuthProperties returns a fresh SamlAuthProperties with every field at its zero value,
// so SAML login is disabled.
func NewSamlAuthProperties() *SamlAuthProperties {
	return new(SamlAuthProperties)
}
// BindSamlAuthProperties binds configuration under PropertiesPrefix onto the defaults from
// NewSamlAuthProperties and returns the result by value. It panics if binding fails.
func BindSamlAuthProperties(ctx *bootstrap.ApplicationContext) SamlAuthProperties {
	bound := NewSamlAuthProperties()
	err := ctx.Config().Bind(bound, PropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind SamlAuthProperties"))
	}
	return *bound
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwdidp
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/config/authserver"
"github.com/cisco-open/go-lanai/pkg/security/csrf"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/formlogin"
"github.com/cisco-open/go-lanai/pkg/security/idp"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
"github.com/cisco-open/go-lanai/pkg/security/request_cache"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"time"
)
// Options configures NewPasswordIdpSecurityConfigurer.
type Options func(opt *option)

// option collects the settings applied by Options.
type option struct {
	Properties   *PwdAuthProperties
	MFAListeners []passwd.MFAEventListenerFunc
}

// WithProperties replaces the PwdAuthProperties the configurer uses
// (the default is NewPwdAuthProperties()).
func WithProperties(props *PwdAuthProperties) Options {
	apply := func(o *option) {
		o.Properties = props
	}
	return apply
}

// WithMFAListeners appends listeners to those handed to the passwd feature.
// It may be given several times; listeners accumulate in call order.
func WithMFAListeners(listeners ...passwd.MFAEventListenerFunc) Options {
	apply := func(o *option) {
		o.MFAListeners = append(o.MFAListeners, listeners...)
	}
	return apply
}
// PasswordIdpSecurityConfigurer implements authserver.IdpSecurityConfigurer
type PasswordIdpSecurityConfigurer struct {
	props        *PwdAuthProperties
	mfaListeners []passwd.MFAEventListenerFunc
}

// NewPasswordIdpSecurityConfigurer creates a PasswordIdpSecurityConfigurer.
// It starts from NewPwdAuthProperties() and an empty MFA listener list, then applies opts in order.
func NewPasswordIdpSecurityConfigurer(opts ...Options) *PasswordIdpSecurityConfigurer {
	settings := &option{
		Properties:   NewPwdAuthProperties(),
		MFAListeners: make([]passwd.MFAEventListenerFunc, 0),
	}
	for _, apply := range opts {
		apply(settings)
	}
	return &PasswordIdpSecurityConfigurer{
		props:        settings.Properties,
		mfaListeners: settings.MFAListeners,
	}
}
// Configure sets up security for authorize requests served by the internal (password/form-login) IdP.
//
// ws is first restricted, via AndCondition, to requests whose host resolves to an IdP with the
// InternalIdpForm flow. This happens even when the IdP is disabled, in which case nothing else is
// configured. When enabled, it adds session handling, an "authenticated for any request" rule, the
// passwd feature (password encoder and MFA/OTP settings from c.props.MFA, plus the registered MFA
// listeners), form login with the endpoint and remember-me settings from c.props, an access-denied
// redirect to config.Endpoints.Error, CSRF protection (skipped for the authorize endpoint path),
// and request caching.
//
// NOTE(review): formlogin's EnableMFA() is called unconditionally, while passwd receives
// MFA(c.props.MFA.Enabled). Confirm that formlogin only exposes MFA pages when passwd requires MFA.
func (c *PasswordIdpSecurityConfigurer) Configure(ws security.WebSecurity, config *authserver.Configuration) {
// For Authorize endpoint
condition := idp.RequestWithAuthenticationFlow(idp.InternalIdpForm, config.IdpManager)
ws = ws.AndCondition(condition)
if !c.props.Enabled {
return
}
// Note: reset password url is not supported by whitelabel login form, and is hardcoded in MSX UI
handler := redirect.NewRedirectWithURL(config.Endpoints.Error)
ws.
With(session.New().SettingService(config.SessionSettingService)).
With(access.New().
Request(matcher.AnyRequest()).Authenticated(),
).
// utils.Duration properties are converted to time.Duration for the feature builders
With(passwd.New().
MFA(c.props.MFA.Enabled).
OtpTTL(time.Duration(c.props.MFA.OtpTTL)).
PasswordEncoder(config.UserPasswordEncoder).
OtpVerifyLimit(c.props.MFA.OtpMaxAttempts).
OtpRefreshLimit(c.props.MFA.OtpResendLimit).
OtpLength(c.props.MFA.OtpLength).
OtpSecretSize(c.props.MFA.OtpSecretSize).
MFAEventListeners(c.mfaListeners...),
).
With(formlogin.New().
EnableMFA().
LoginUrl(c.props.Endpoints.FormLogin).
LoginProcessUrl(c.props.Endpoints.FormLoginProcess).
LoginErrorUrl(c.props.Endpoints.FormLoginError).
MfaUrl(c.props.Endpoints.OtpVerify).
MfaVerifyUrl(c.props.Endpoints.OtpVerifyProcess).
MfaRefreshUrl(c.props.Endpoints.OtpVerifyResend).
MfaErrorUrl(c.props.Endpoints.OtpVerifyError).
RememberCookieSecured(c.props.RememberMe.UseSecureCookie).
RememberCookieDomain(c.props.RememberMe.CookieDomain).
RememberCookieValidity(time.Duration(c.props.RememberMe.CookieValidity)),
).
With(errorhandling.New().
AccessDeniedHandler(handler),
).
// requests to the authorize endpoint path are exempt from CSRF checks
With(csrf.New().
IgnoreCsrfProtectionMatcher(matcher.RequestWithPattern(config.Endpoints.Authorize.Location.Path)),
).
With(request_cache.New())
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwdidp
import (
"github.com/cisco-open/go-lanai/pkg/security/formlogin"
"github.com/cisco-open/go-lanai/pkg/web"
)
// NewWhiteLabelLoginFormController returns the default form-login controller wired to the embedded
// white-label templates ("login.tmpl" and "otp_verify.tmpl") and their parameter names and URLs.
func NewWhiteLabelLoginFormController() web.Controller {
	whiteLabel := func(o *formlogin.DefaultFormLoginPageOptions) {
		// login page
		o.LoginTemplate = "login.tmpl"
		o.LoginProcessUrl = "/login"
		o.UsernameParam = "username"
		o.PasswordParam = "password"
		// MFA (OTP) page
		o.MfaTemplate = "otp_verify.tmpl"
		o.MfaVerifyUrl = "/login/mfa"
		o.MfaRefreshUrl = "/login/mfa/refresh"
		o.OtpParam = "otp"
	}
	return formlogin.NewDefaultLoginFormController(whiteLabel)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwdidp
import "github.com/cisco-open/go-lanai/pkg/security/idp"
// PasswdIdpDetails holds the configuration of an internal (password) identity provider.
type PasswdIdpDetails struct {
	Domain string
}

// PasswdIdpOptions mutates PasswdIdpDetails during NewIdentityProvider.
type PasswdIdpOptions func(opt *PasswdIdpDetails)

// PasswdIdentityProvider implements idp.IdentityProvider and idp.AuthenticationFlowAware
//
// The Domain method shadows the promoted Domain field, hence the explicit p.PasswdIdpDetails.Domain.
type PasswdIdentityProvider struct {
	PasswdIdpDetails
}

// NewIdentityProvider applies every option, in order, to a zero-valued PasswdIdpDetails and
// returns a provider holding a copy of the result.
func NewIdentityProvider(opts ...PasswdIdpOptions) *PasswdIdentityProvider {
	var details PasswdIdpDetails
	for _, apply := range opts {
		apply(&details)
	}
	return &PasswdIdentityProvider{PasswdIdpDetails: details}
}

// AuthenticationFlow always reports idp.InternalIdpForm.
func (p PasswdIdentityProvider) AuthenticationFlow() idp.AuthenticationFlow {
	return idp.InternalIdpForm
}

// Domain returns the configured domain.
func (p PasswdIdentityProvider) Domain() string {
	return p.PasswdIdpDetails.Domain
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwdidp
import (
"embed"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/web"
"go.uber.org/fx"
)
//var logger = log.New("SEC.Passwd")
// OrderWhiteLabelTemplateFS is the order used when registering the embedded white-label FS.
// OrderTemplateFSOverwrite is 1000 lower, presumably for application FS meant to take precedence over it — confirm against web.OrderedFS ordering.
const (
OrderWhiteLabelTemplateFS = 0
OrderTemplateFSOverwrite = OrderWhiteLabelTemplateFS - 1000
)
// whiteLabelContent embeds the white-label web content under web/whitelabel (presumably including login.tmpl and otp_verify.tmpl).
//go:embed web/whitelabel/*
var whiteLabelContent embed.FS
// defaultConfigFS embeds the module's default configuration file, registered via appconfig.FxEmbeddedDefaults.
//go:embed defaults-passwd-auth.yml
var defaultConfigFS embed.FS
// Module registers the embedded default configuration, provides PwdAuthProperties and registers the white-label FS.
var Module = &bootstrap.Module{
Name: "password IDP",
Precedence: security.MaxSecurityPrecedence - 100,
Options: []fx.Option {
appconfig.FxEmbeddedDefaults(defaultConfigFS),
fx.Provide(BindPwdAuthProperties),
fx.Invoke(register),
},
}
// Use registers Module with bootstrap.
func Use() {
bootstrap.Register(Module)
}
// register adds the embedded white-label FS to the web registrar at OrderWhiteLabelTemplateFS; MustRegister presumably panics on failure.
func register(r *web.Registrar) {
r.MustRegister(web.OrderedFS(whiteLabelContent, OrderWhiteLabelTemplateFS))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwdidp
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"time"
)
const (
// PwdAuthPropertiesPrefix is the configuration prefix bound into PwdAuthProperties.
PwdAuthPropertiesPrefix = "security.idp.internal"
)
// PwdAuthProperties configures the internal (password/form-login) identity provider.
type PwdAuthProperties struct {
// Enabled: when false, PasswordIdpSecurityConfigurer.Configure adds only the IdP request condition.
Enabled bool `json:"enabled"`
// Domain defaults to "localhost" in NewPwdAuthProperties; its consumer is not visible here.
Domain string `json:"domain"`
// SessionExpiredRedirectUrl is not read in the code visible here.
SessionExpiredRedirectUrl string `json:"session-expired-redirect-url"`
// Endpoints are the form-login and OTP URLs passed to the formlogin feature.
Endpoints PwdAuthEndpointProperties `json:"endpoints"`
// MFA holds the OTP settings passed to the passwd feature.
MFA PwdAuthMfaProperties `json:"mfa"`
// RememberMe holds the remembered-username cookie settings passed to the formlogin feature.
RememberMe RememberMeProperties `json:"remember-me"`
}
// PwdAuthEndpointProperties lists the URLs used by form login and OTP verification.
type PwdAuthEndpointProperties struct {
FormLogin string `json:"form-login"`
FormLoginProcess string `json:"form-login-process"`
FormLoginError string `json:"form-login-error"`
OtpVerify string `json:"otp-verify"`
OtpVerifyProcess string `json:"otp-verify-process"`
OtpVerifyResend string `json:"otp-verify-resend"`
OtpVerifyError string `json:"otp-verify-error"`
// ResetPasswordPageUrl is not passed to formlogin by Configure (the white-label login form does not support it).
ResetPasswordPageUrl string `json:"reset-password-page-url"`
}
// PwdAuthMfaProperties holds the OTP-based MFA settings.
type PwdAuthMfaProperties struct {
Enabled bool `json:"enabled"`
OtpLength uint `json:"otp-length"`
OtpSecretSize uint `json:"otp-secret-size"`
// OtpTTL is converted to time.Duration when passed to the passwd feature.
OtpTTL utils.Duration `json:"otp-ttl"`
// OtpMaxAttempts is passed as the passwd OTP verify limit.
OtpMaxAttempts uint `json:"otp-max-attempts"`
// OtpResendLimit is passed as the passwd OTP refresh limit.
OtpResendLimit uint `json:"otp-resend-limit"`
}
// RememberMeProperties configures the remembered-username cookie.
type RememberMeProperties struct {
CookieDomain string `json:"cookie-domain"`
UseSecureCookie bool `json:"use-secure-cookie"`
// CookieValidity is the cookie lifetime, converted to time.Duration for formlogin.
CookieValidity utils.Duration `json:"cookie-validity"`
}
// NewPwdAuthProperties returns the built-in defaults that bound configuration then overrides.
// The internal IdP is disabled (Enabled is left false) and the domain is "localhost".
// The form-login and MFA endpoints live under /login. MFA is enabled with OTP length 6,
// a 10-byte secret size, a 5-minute TTL, and 5 verify attempts and 5 resends. The
// remember-me cookie is valid for two weeks.
func NewPwdAuthProperties() *PwdAuthProperties {
	endpoints := PwdAuthEndpointProperties{
		FormLogin:            "/login",
		FormLoginProcess:     "/login",
		FormLoginError:       "/login?error=true",
		OtpVerify:            "/login/mfa",
		OtpVerifyProcess:     "/login/mfa",
		OtpVerifyResend:      "/login/mfa/refresh",
		OtpVerifyError:       "/login/mfa?error=true",
		ResetPasswordPageUrl: "http://localhost:9003/#/forgotpassword",
	}
	mfa := PwdAuthMfaProperties{
		Enabled:        true,
		OtpLength:      6,
		OtpSecretSize:  10,
		OtpTTL:         utils.Duration(5 * time.Minute),
		OtpMaxAttempts: 5,
		OtpResendLimit: 5,
	}
	rememberMe := RememberMeProperties{
		// two weeks
		CookieValidity: utils.Duration(14 * 24 * time.Hour),
	}
	return &PwdAuthProperties{
		Domain:     "localhost",
		Endpoints:  endpoints,
		MFA:        mfa,
		RememberMe: rememberMe,
	}
}
// BindPwdAuthProperties binds configuration under PwdAuthPropertiesPrefix onto the defaults from
// NewPwdAuthProperties and returns the result by value. It panics if binding fails.
func BindPwdAuthProperties(ctx *bootstrap.ApplicationContext) PwdAuthProperties {
	bound := NewPwdAuthProperties()
	err := ctx.Config().Bind(bound, PwdAuthPropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind PwdAuthProperties"))
	}
	return *bound
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package unknownIdp
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/config/authserver"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/idp"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
)
// NoIdpSecurityConfigurer configures the security chain for requests whose host has no usable
// identity provider (the UnknownIdp flow).
type NoIdpSecurityConfigurer struct{}

// NewNoIdpSecurityConfigurer returns a new NoIdpSecurityConfigurer; the type has no settings.
func NewNoIdpSecurityConfigurer() *NoIdpSecurityConfigurer {
	return new(NoIdpSecurityConfigurer)
}
// Configure handles authorize requests whose host resolves to the UnknownIdp flow, meaning there is no IdP or no flow.
// It adds session handling and allows only requests that are already authenticated. Anything else is
// denied with "Identity provider is not configured for this sub-domain", and the access-denied handler
// redirects to config.Endpoints.Error.
func (c *NoIdpSecurityConfigurer) Configure(ws security.WebSecurity, config *authserver.Configuration) {
// For Authorize endpoint
handler := redirect.NewRedirectWithURL(config.Endpoints.Error)
condition := idp.RequestWithAuthenticationFlow(idp.UnknownIdp, config.IdpManager)
ws.AndCondition(condition).
With(session.New().SettingService(config.SessionSettingService)).
With(access.New().
Request(matcher.AnyRequest()).
AllowIf(authenticatedWithMessage("Identity provider is not configured for this sub-domain")),
).
With(errorhandling.New().
AccessDeniedHandler(handler),
)
}
// authenticatedWithMessage returns an access.ControlFunc that grants access only once the
// authentication has reached security.StateAuthenticated. Otherwise, including when auth is nil,
// access is denied with security.NewAccessDeniedError carrying fmt.Sprintf(format, v...).
func authenticatedWithMessage(format string, v ...interface{}) access.ControlFunc {
	// The message depends only on the arguments, so format it once rather than on every denial.
	msg := fmt.Sprintf(format, v...)
	return func(auth security.Authentication) (decision bool, reason error) {
		if auth != nil && auth.State() >= security.StateAuthenticated {
			return true, nil
		}
		return false, security.NewAccessDeniedError(msg)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/pkg/web"
"go.uber.org/fx"
"reflect"
"sort"
"sync"
)
/************************************
Security Initialization
*************************************/
// initializer collects security Configurers and FeatureConfigurers; Initialize builds them into web mappings.
type initializer struct {
// initialized is set once Initialize completes successfully.
initialized bool
// initializing is set while Initialize runs. Registration is rejected while either flag is true.
initializing bool
// featureConfigurers maps each feature identifier to the FeatureConfigurer that applies it.
featureConfigurers map[FeatureIdentifier]FeatureConfigurer
// configurers are sorted and built, in order, by Initialize.
configurers []Configurer
// globalAuthenticator is the Authenticator given to newSecurity; its use is not visible in this chunk.
globalAuthenticator Authenticator
}
// initializeMutex serializes calls to (*initializer).Initialize.
var initializeMutex sync.Mutex
// newSecurity creates an uninitialized initializer with empty registries, using globalAuth as its
// global Authenticator.
func newSecurity(globalAuth Authenticator) *initializer {
	s := &initializer{globalAuthenticator: globalAuth}
	s.featureConfigurers = make(map[FeatureIdentifier]FeatureConfigurer)
	s.configurers = make([]Configurer, 0)
	return s
}
// Register appends configurers to the set that Initialize builds.
// It panics once Initialize has started or finished. It is not thread-safe and is usually
// called in "fx.Invoke" or "fx.Provide".
func (init *initializer) Register(configurers ...Configurer) {
	err := init.validateState("register security.Configurer")
	if err != nil {
		panic(err)
	}
	init.configurers = append(init.configurers, configurers...)
}
// RegisterFeature associates featureConfigurer with featureId, replacing any earlier registration.
// It panics once Initialize has started or finished. It is not thread-safe and is usually
// called in "fx.Invoke" or "fx.Provide".
func (init *initializer) RegisterFeature(featureId FeatureIdentifier, featureConfigurer FeatureConfigurer) {
	err := init.validateState("register security.FeatureConfigurer")
	if err != nil {
		panic(err)
	}
	init.featureConfigurers[featureId] = featureConfigurer
}
// FindFeature returns the FeatureConfigurer registered for featureId, or nil if there is none.
// It is not thread-safe and is usually called in "fx.Invoke" or "fx.Provide".
func (init *initializer) FindFeature(featureId FeatureIdentifier) FeatureConfigurer {
	// A missing key already yields the zero value (a nil interface); the former `f, _ :=`
	// form was a redundant blank assignment (staticcheck S1005).
	return init.featureConfigurers[featureId]
}
// validateState returns an error describing why action is not allowed once Initialize has
// finished or is in progress, and nil otherwise. A finished initialization is reported first.
func (init *initializer) validateState(action string) error {
	if init.initialized {
		return fmt.Errorf("cannot %s: security already initialized", action)
	}
	if init.initializing {
		return fmt.Errorf("cannot %s: security already started initializing", action)
	}
	return nil
}
// Initialize builds every registered Configurer and registers the resulting web mappings.
//
// It may run only once. A second call, including one made after a failed attempt, returns an error.
// Configurers are sorted with order.OrderedFirstCompare and built in that order. When registrar is
// non-nil, each mapping is registered and logged. Request pre-processors from all builds are merged
// by name, with the first configurer to supply a name winning, and registered after every mapping.
// Because they are ranged from a map, their registration order is unspecified. The fx.Lifecycle
// parameter is unused.
//
// NOTE(review): on an error return, initializing stays true and initialized stays false, so later
// Register/RegisterFeature calls panic and Initialize cannot be retried — confirm this is intended.
func (init *initializer) Initialize(ctx context.Context, _ fx.Lifecycle, registrar *web.Registrar) error {
initializeMutex.Lock()
defer initializeMutex.Unlock()
if init.initialized || init.initializing {
return fmt.Errorf("security.Initializer.initialize cannot be called twice")
}
init.initializing = true
// sort configurer (sort.Slice is not stable: configurers that compare equal may be reordered)
sort.Slice(init.configurers, func(i,j int) bool {
return order.OrderedFirstCompare(init.configurers[i], init.configurers[j])
})
mergedRequestPreProcessors := make(map[web.RequestPreProcessorName]web.RequestPreProcessor)
// go through each configurer
for _,configurer := range init.configurers {
builder, requestPreProcessors, err := init.build(ctx, configurer)
if err != nil {
return err
}
// first configurer (in sorted order) to provide a pre-processor name wins
for k, v := range requestPreProcessors {
if _, ok := mergedRequestPreProcessors[k]; !ok {
mergedRequestPreProcessors[k] = v
}
}
mappings := builder.Build()
// register web.Mapping if possible
if registrar == nil {
continue
}
for _,mapping := range mappings {
if err := registrar.Register(mapping); err != nil {
return err
}
// Do some logging
logMapping(ctx, mapping)
}
}
if registrar != nil {
// MustRegister presumably panics on failure rather than returning an error
for _, v := range mergedRequestPreProcessors {
registrar.MustRegister(v)
}
}
init.initialized = true
init.initializing = false
return nil
}
// build runs configurer against a fresh WebSecurity, applies every enabled feature through its
// registered FeatureConfigurer, and finalizes authenticators via process.
// It returns the WebSecurity (as mapping builder) and any request pre-processors shared under
// WSSharedKeyRequestPreProcessors. Errors from feature configurers are wrapped with %w so callers
// can still inspect the underlying cause with errors.Is / errors.As.
func (init *initializer) build(ctx context.Context, configurer Configurer) (WebSecurityMappingBuilder, map[web.RequestPreProcessorName]web.RequestPreProcessor, error) {
	// collect security configs; composite handlers are shared so features can contribute to them
	ws := newWebSecurity(ctx, NewAuthenticator(), map[string]interface{}{
		WSSharedKeyCompositeAuthSuccessHandler:  NewAuthenticationSuccessHandler(),
		WSSharedKeyCompositeAuthErrorHandler:    NewAuthenticationErrorHandler(),
		WSSharedKeyCompositeAccessDeniedHandler: NewAccessDeniedHandler(),
	})
	configurer.Configure(ws)

	// Apply features in order. A feature's configurer may enable/disable other features while we
	// iterate, so ws.Features() is re-read on every iteration. Changes that touch already-applied
	// features, or add a feature ordered before the current one, are rejected with an error by
	// handleFeaturesChanged. Re-sorting on change costs a little, which is an accepted trade-off.
	sortFeatures(ws.Features())
	for i := 0; i < len(ws.Features()); i++ {
		f := ws.Features()[i]
		fc, ok := init.featureConfigurers[f.Identifier()]
		if !ok {
			return nil, nil, fmt.Errorf("unable to build security feature %T: no FeatureConfigurer found", f)
		}

		// mark as applied and reset the change flag before applying
		ws.applied[f.Identifier()] = struct{}{}
		ws.featuresChanged = false
		if err := fc.Apply(f, ws); err != nil {
			return nil, nil, fmt.Errorf("error during process WebSecurity [%v]: %w", ws, err)
		}

		// the applied configurer may have enabled/disabled features (ws.Features() differs)
		if !ws.featuresChanged {
			continue
		}
		if err := init.handleFeaturesChanged(i, f, ws.Features()); err != nil {
			return nil, nil, fmt.Errorf("error during process WebSecurity [%v]: %w", ws, err)
		}
	}

	if err := init.process(ws); err != nil {
		return nil, nil, err
	}

	var processors map[web.RequestPreProcessorName]web.RequestPreProcessor
	if shared := ws.Shared(WSSharedKeyRequestPreProcessors); shared != nil {
		processors = shared.(map[web.RequestPreProcessorName]web.RequestPreProcessor)
	}
	return ws, processors, nil
}
// handleFeaturesChanged is invoked when the feature list changed while applying feature f at index i.
// It
//  1. checks that f is still at index i (otherwise an already-applied feature at or before i was removed),
//  2. re-sorts the remaining (not yet applied) features after i,
//  3. checks that none of the remaining features (likely newly added) orders before f.
//
// Step 1 must run before the "f is last" shortcut: if an earlier feature is removed, f's index
// shifts down, and skipping the check would silently cause the following features never to be applied.
func (init *initializer) handleFeaturesChanged(i int, f Feature, features []Feature) error {
	// step 1
	if i >= len(features) || features[i] != f {
		return fmt.Errorf("feature configurer for [%v] attempted to disable already applied features", f.Identifier())
	}
	if i == len(features)-1 {
		// f is still the last one: nothing left to re-sort or check
		return nil
	}
	// step 2
	sortFeatures(features[i+1:])
	// step 3: after sorting, the smallest remaining feature is at i+1
	next := features[i+1]
	if featureOrderLess(next, f) {
		return fmt.Errorf("feature configurer for [%v] attempted to enable feature [%v] which won't be processed", f.Identifier(), next.Identifier())
	}
	return nil
}
// process validates ws and makes sure it ends up with an authenticator:
//   - it fails when no middleware handler was configured;
//   - a concrete authenticator already on ws is kept as-is;
//   - otherwise the global authenticator is merged (if composite) or added;
//   - if neither is concrete, an AnonymousAuthenticator is added.
func (init *initializer) process(ws *webSecurity) error {
	if len(ws.handlers) == 0 {
		return fmt.Errorf("no middleware were configuered for WebSecurity %v", ws)
	}
	if hasConcreteAuthenticator(ws.authenticator) {
		return nil
	}

	composite := ws.authenticator.(*CompositeAuthenticator)
	global := init.globalAuthenticator
	if !hasConcreteAuthenticator(global) {
		// nothing concrete anywhere: fall back to anonymous authentication
		composite.Add(&AnonymousAuthenticator{})
		return nil
	}

	// ws has no concrete authenticator, but the global one does: borrow it
	if globalComposite, ok := global.(*CompositeAuthenticator); ok {
		composite.Merge(globalComposite)
	} else {
		composite.Add(global)
	}
	return nil
}
// hasConcreteAuthenticator reports whether auth can actually authenticate: a nil interface is not,
// a CompositeAuthenticator is only when it wraps at least one authenticator, anything else is.
func hasConcreteAuthenticator(auth Authenticator) bool {
	switch a := auth.(type) {
	case nil:
		return false
	case *CompositeAuthenticator:
		return len(a.authenticators) > 0
	default:
		return true
	}
}
// sortFeatures sorts features in place by their identifiers' order (see featureOrderLess).
func sortFeatures(features []Feature) {
	less := func(a, b int) bool {
		return featureOrderLess(features[a], features[b])
	}
	sort.Slice(features, less)
}
// featureOrderLess reports whether l should be applied before r, comparing their FeatureIdentifiers
// with order.OrderedFirstCompare.
func featureOrderLess(l Feature, r Feature) bool {
	left, right := l.Identifier(), r.Identifier()
	return order.OrderedFirstCompare(left, right)
}
// logMapping logs a registered security mapping at INFO level, with details that depend on its kind
// (middleware, MVC or simple mapping). Conditions and matchers are capped to 80 characters.
func logMapping(ctx context.Context, mapping web.Mapping) {
	// bind the asserted value in the type switch instead of re-asserting in each case (S1034)
	switch m := mapping.(type) {
	case web.MiddlewareMapping:
		logger.WithContext(ctx).Infof("registered security middleware [%d] [%s] %s -> %v",
			m.Order(), m.Name(), log.Capped(m.Matcher(), 80), reflect.ValueOf(m.HandlerFunc()).String())
	case web.MvcMapping:
		logger.WithContext(ctx).Infof("registered security MVC mapping [%s %s] [%s] %s -> endpoint",
			m.Method(), m.Path(), m.Name(), log.Capped(m.Condition(), 80))
	case web.SimpleMapping:
		logger.WithContext(ctx).Infof("registered security simple mapping [%s %s] [%s] %s -> %v",
			m.Method(), m.Path(), m.Name(), log.Capped(m.Condition(), 80), reflect.ValueOf(m.HandlerFunc()).String())
	default:
		logger.WithContext(ctx).Infof("registered security mapping [%s]: %v", mapping.Name(), mapping)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/gin-gonic/gin"
"net/http"
)
/***************************************
Additional Context for Internal
****************************************/
// FeatureModifier adds or removes features.
// It should not be used directly by services; use the corresponding feature's Configure(WebSecurity) instead.
type FeatureModifier interface {
	// Enable kicks off configuration of the given Feature.
	// If the Feature is not enabled yet, it is added to the receiver and returned;
	// if it is already enabled, the previously enabled Feature is returned instead.
	Enable(feature Feature) Feature
	// Disable removes the given feature, identified by its FeatureIdentifier.
	Disable(feature Feature)
}

// WebSecurityReader gives read access to the route, request condition and handlers of a security configuration.
type WebSecurityReader interface {
	GetRoute() web.RouteMatcher
	GetCondition() web.RequestMatcher
	GetHandlers() []interface{}
}

// WebSecurityMappingBuilder turns a finished security configuration into web mappings to be registered.
type WebSecurityMappingBuilder interface {
	Build() []web.Mapping
}

// FeatureConfigurer applies a Feature to a WebSecurity. Not intended to be used directly in services.
// A non-nil error aborts building the WebSecurity.
type FeatureConfigurer interface {
	Apply(feature Feature, ws WebSecurity) error
}

// FeatureRegistrar maps feature identifiers to their FeatureConfigurer.
type FeatureRegistrar interface {
	// RegisterFeature is typically used by feature packages, such as session, oauth, etc.
	// Not intended to be used directly in services.
	RegisterFeature(featureId FeatureIdentifier, featureConfigurer FeatureConfigurer)
	// FindFeature is typically used by feature packages.
	FindFeature(featureId FeatureIdentifier) FeatureConfigurer
}
// NoopHandlerFunc returns a gin handler that aborts the request with 404 "page not found".
// It serves as the terminal endpoint for paths that are fully handled by security middleware.
func NoopHandlerFunc() gin.HandlerFunc {
	notFound := func(gc *gin.Context) {
		// the *gin.Error returned by AbortWithError is not needed here
		_ = gc.AbortWithError(http.StatusNotFound, fmt.Errorf("page not found"))
	}
	return notFound
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"fmt"
"net/url"
pathutils "path"
"strings"
)
// UrlBuilderOptions is a functional option that customizes a UrlBuilderOption before Issuer.BuildUrl uses it.
type UrlBuilderOptions func(opt *UrlBuilderOption)

// UrlBuilderOption holds the parameters of Issuer.BuildUrl.
type UrlBuilderOption struct {
	// FQDN is the host to use; when empty, Issuer.Domain() is used.
	FQDN string
	// Path is joined onto Issuer.ContextPath(); it should not include the context path itself.
	Path string
}

// Issuer describes the externally visible location (protocol, domain, port, context path)
// of the deployed auth server.
type Issuer interface {
	// Protocol returns the URL scheme, e.g. "http" or "https".
	Protocol() string
	// Domain returns the issuer's base domain.
	Domain() string
	// Port returns the issuer's port.
	Port() int
	// ContextPath returns the path prefix the auth server is served under.
	ContextPath() string
	// IsSecured reports whether the issuer is served over HTTPS.
	IsSecured() bool
	// Identifier is the unique identifier of the deployed auth server.
	// A typical implementation uses the base URL of the issuer's domain.
	Identifier() string
	// LevelOfAssurance constructs a level-of-assurance string for the given level.
	// Level of assurance represents how confident the auth issuer is about the user's identity.
	// ref: https://developer.mobileconnect.io/level-of-assurance
	LevelOfAssurance(level int) string
	// BuildUrl builds a URL from the given url builder options.
	// Implementation specs:
	// 1. if UrlBuilderOption.FQDN is not specified, Issuer.Domain() should be used
	// 2. if UrlBuilderOption.FQDN is not a subdomain of Issuer.Domain(), an error should be returned
	// 3. UrlBuilderOption.Path should be assumed NOT to include Issuer.ContextPath, and the generated URL
	//    always includes Issuer.ContextPath
	// 4. if UrlBuilderOption.Path is not specified, the generated URL can be used as a base URL
	// 5. BuildUrl should not return an error when no options are provided
	BuildUrl(...UrlBuilderOptions) (*url.URL, error)
}

/***************************
	Default Impl.
 ***************************/

// DefaultIssuerDetails is the plain configuration backing DefaultIssuer.
type DefaultIssuerDetails struct {
	// Protocol is the URL scheme, e.g. "https".
	Protocol string
	// Domain is the issuer's base domain; BuildUrl only accepts this domain or its subdomains.
	Domain string
	// Port is appended to the host only when IncludePort is true.
	Port int
	// ContextPath is prepended to every path built by BuildUrl.
	ContextPath string
	// IncludePort controls whether ":Port" is part of generated URLs.
	IncludePort bool
}

// DefaultIssuer implements Issuer on top of DefaultIssuerDetails.
type DefaultIssuer struct {
	DefaultIssuerDetails
}

// NewIssuer creates a DefaultIssuer from zero-valued details customized by opts, applied in order.
func NewIssuer(opts ...func(*DefaultIssuerDetails)) *DefaultIssuer {
	details := DefaultIssuerDetails{}
	for _, apply := range opts {
		apply(&details)
	}
	return &DefaultIssuer{DefaultIssuerDetails: details}
}

func (i DefaultIssuer) Protocol() string {
	return i.DefaultIssuerDetails.Protocol
}

func (i DefaultIssuer) Domain() string {
	return i.DefaultIssuerDetails.Domain
}

func (i DefaultIssuer) Port() int {
	return i.DefaultIssuerDetails.Port
}

func (i DefaultIssuer) ContextPath() string {
	return i.DefaultIssuerDetails.ContextPath
}

// IsSecured reports whether Protocol is "https", ignoring case.
func (i DefaultIssuer) IsSecured() bool {
	return strings.EqualFold(i.DefaultIssuerDetails.Protocol, "https")
}

// Identifier returns the issuer's base URL (scheme, domain, optional port, context path).
func (i DefaultIssuer) Identifier() string {
	id, err := i.BuildUrl()
	if err != nil {
		// not expected: BuildUrl without options always uses Domain, which is accepted.
		// Guard anyway so a misconfiguration cannot cause a nil-pointer panic.
		return ""
	}
	return id.String()
}

// LevelOfAssurance returns "<base URL>/loa-<level>".
func (i DefaultIssuer) LevelOfAssurance(level int) string {
	path := fmt.Sprintf("/loa-%d", level)
	loa, err := i.BuildUrl(func(opt *UrlBuilderOption) {
		opt.Path = path
	})
	if err != nil {
		// not expected: only Path is customized, so the default domain is used
		return ""
	}
	return loa.String()
}

// BuildUrl builds a URL following the Issuer.BuildUrl spec.
// It returns an error when the requested FQDN is neither Domain itself nor a proper subdomain of it.
func (i DefaultIssuer) BuildUrl(options ...UrlBuilderOptions) (*url.URL, error) {
	opt := UrlBuilderOption{}
	for _, f := range options {
		f(&opt)
	}
	domain := i.DefaultIssuerDetails.Domain
	if opt.FQDN == "" {
		opt.FQDN = domain
	}

	// Spec item 2. A plain suffix match is not enough: "evilexample.com" ends with "example.com"
	// but is not its subdomain. A leading "." (e.g. ".example.com") is not a valid host either.
	// With an empty Domain there is nothing to restrict against.
	if opt.FQDN != domain {
		if strings.HasPrefix(opt.FQDN, ".") || (domain != "" && !strings.HasSuffix(opt.FQDN, "."+domain)) {
			return nil, fmt.Errorf("invalid subdomain %s", opt.FQDN)
		}
	}

	ret := &url.URL{
		Scheme: i.DefaultIssuerDetails.Protocol,
		Host:   opt.FQDN,
		Path:   i.DefaultIssuerDetails.ContextPath,
	}
	if i.DefaultIssuerDetails.IncludePort {
		ret.Host = fmt.Sprintf("%s:%d", ret.Host, i.DefaultIssuerDetails.Port)
	}
	if opt.Path != "" {
		path := pathutils.Join(ret.Path, opt.Path)
		ret = ret.ResolveReference(&url.URL{Path: path})
	}
	return ret, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package logout
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/pkg/web/mapping"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
"net/http"
)
// FeatureId identifies the logout feature; security.FeatureOrderLogout positions it among other features.
var FeatureId = security.FeatureId("Logout", security.FeatureOrderLogout)
// LogoutConfigurer is the security.FeatureConfigurer for LogoutFeature. It holds no state.
//
//goland:noinspection GoNameStartsWithPackageName
type LogoutConfigurer struct{}

func newLogoutConfigurer() *LogoutConfigurer {
	return new(LogoutConfigurer)
}
// Apply installs the logout feature onto ws:
//   - ws is routed to the logout URL for GET/POST/PUT/DELETE;
//   - a logout middleware runs the (order-sorted) logout handlers;
//   - dummy endpoints are registered for each method, so the path exists for the middleware to run on.
func (c *LogoutConfigurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	f := feature.(*LogoutFeature)
	if err := c.validate(f, ws); err != nil {
		return err
	}

	methods := []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete}
	route := matcher.RouteWithPattern(f.logoutUrl, methods...)
	ws.Route(route)

	// This MW handles a new path, so it is added as-is rather than as a security.MiddlewareTemplate.
	order.SortStable(f.logoutHandlers, order.OrderedFirstCompare)
	successHandler := c.effectiveSuccessHandler(f, ws)
	errorHandler := c.effectiveErrorHandler(f, ws)
	entryPoint := c.effectiveEntryPoints(f)
	logoutMW := NewLogoutMiddleware(successHandler, errorHandler, entryPoint, f.logoutHandlers...)
	ws.Add(middleware.NewBuilder("logout").
		ApplyTo(route).
		Order(security.MWOrderFormLogout).
		Use(logoutMW.LogoutHandlerFunc()))

	for _, m := range methods {
		dummy := mapping.New("logout dummy " + m).
			Method(m).Path(f.logoutUrl).
			HandlerFunc(security.NoopHandlerFunc())
		ws.Add(dummy)
	}
	return nil
}
// validate checks that the logout URL is set and that a success URL or a success handler is available.
func (c *LogoutConfigurer) validate(f *LogoutFeature, _ security.WebSecurity) error {
	switch {
	case f.logoutUrl == "":
		return fmt.Errorf("logoutUrl is missing for logout")
	case f.successUrl == "" && len(f.successHandlers) == 0:
		return fmt.Errorf("successUrl and successHandler are both missing for logout")
	default:
		return nil
	}
}
// effectiveSuccessHandler composes the success handler used after logout.
//   - Feature handlers are used, or, if there are none, a redirect to the success URL.
//   - The shared composite success handler is placed in front of them, when present.
//   - The result is stable-sorted with order.OrderedFirstCompare.
// f.successHandlers itself is never modified.
func (c *LogoutConfigurer) effectiveSuccessHandler(f *LogoutFeature, ws security.WebSecurity) security.AuthenticationSuccessHandler {
	handlers := append(make([]security.AuthenticationSuccessHandler, 0, len(f.successHandlers)+2), f.successHandlers...)
	if len(handlers) == 0 {
		handlers = append(handlers, redirect.NewRedirectWithURL(f.successUrl))
	}

	global, hasGlobal := ws.Shared(security.WSSharedKeyCompositeAuthSuccessHandler).(security.AuthenticationSuccessHandler)
	if hasGlobal {
		// shift right by one and put the global handler first
		handlers = append(handlers, nil)
		copy(handlers[1:], handlers)
		handlers[0] = global
	}

	order.SortStable(handlers, order.OrderedFirstCompare)
	return security.NewAuthenticationSuccessHandler(handlers...)
}
// effectiveErrorHandler composes the error handler used when logout fails.
//   - Feature handlers are used, or, if there are none, a redirect to the error URL.
//   - The shared composite error handler is appended after them, when present.
// f.errorHandlers itself is never modified.
func (c *LogoutConfigurer) effectiveErrorHandler(f *LogoutFeature, ws security.WebSecurity) security.AuthenticationErrorHandler {
	handlers := make([]security.AuthenticationErrorHandler, 0, len(f.errorHandlers)+2)
	if len(f.errorHandlers) > 0 {
		handlers = append(handlers, f.errorHandlers...)
	} else {
		handlers = append(handlers, redirect.NewRedirectWithURL(f.errorUrl))
	}

	global, hasGlobal := ws.Shared(security.WSSharedKeyCompositeAuthErrorHandler).(security.AuthenticationErrorHandler)
	if hasGlobal {
		handlers = append(handlers, global)
	}
	return security.NewAuthenticationErrorHandler(handlers...)
}
// effectiveEntryPoints returns the configured entry points, stable-sorted in place, as one entry point,
// or nil when none are configured.
func (c *LogoutConfigurer) effectiveEntryPoints(f *LogoutFeature) security.AuthenticationEntryPoint {
	if len(f.entryPoints) > 0 {
		order.SortStable(f.entryPoints, order.OrderedFirstCompare)
		return multiEntryPoints(f.entryPoints)
	}
	return nil
}
// multiEntryPoints fans a Commence call out to every wrapped entry point, in slice order.
type multiEntryPoints []security.AuthenticationEntryPoint

func (ep multiEntryPoints) Commence(ctx context.Context, request *http.Request, writer http.ResponseWriter, err error) {
	for i := range ep {
		ep[i].Commence(ctx, request, writer, err)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package logout
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"net/http"
)
/*********************************
Feature Impl
*********************************/
// Warnings collects non-fatal errors reported by LogoutHandlers during a logout request.
type Warnings []error

// LogoutHandler performs one logging-out action.
//
//goland:noinspection GoNameStartsWithPackageName
type LogoutHandler interface {
	// HandleLogout is called by the logout middleware to perform logging-out actions.
	// When multiple LogoutHandlers are registered, an implementation can cancel logout up front
	// by also implementing ConditionalLogoutHandler.
	// If the returned error is a security.ErrorSubTypeAuthWarning, it is recorded as a warning and
	// the remaining handlers still run; the success handler is used if no other error occurs.
	HandleLogout(ctx context.Context, req *http.Request, rw http.ResponseWriter, auth security.Authentication) error
}

// ConditionalLogoutHandler is a supplementary interface for LogoutHandler.
// It can cancel or delay logout before any LogoutHandler is executed.
// When a non-nil error is returned and the logout middleware has a security.AuthenticationEntryPoint,
// that entry point is used to delay the logout process; otherwise the error handler is used.
// With multiple ConditionalLogoutHandlers, an error from any one of them terminates the process immediately.
type ConditionalLogoutHandler interface {
	// ShouldLogout returns an error if logging out cannot be performed.
	ShouldLogout(ctx context.Context, req *http.Request, rw http.ResponseWriter, auth security.Authentication) error
}
// LogoutFeature holds the logout configuration applied by LogoutConfigurer.
//
//goland:noinspection GoNameStartsWithPackageName
type LogoutFeature struct {
// successHandlers, when non-empty, replace the default redirect to successUrl
successHandlers []security.AuthenticationSuccessHandler
// errorHandlers, when non-empty, replace the default redirect to errorUrl
errorHandlers []security.AuthenticationErrorHandler
// entryPoints are used when a ConditionalLogoutHandler cancels/delays logout
entryPoints []security.AuthenticationEntryPoint
successUrl string
errorUrl string
// logoutHandlers run in order (after stable sorting) during logout
logoutHandlers []LogoutHandler
// logoutUrl is the path intercepted by the logout middleware
logoutUrl string
}
// Identifier is the standard security.Feature entrypoint; every LogoutFeature reports FeatureId.
func (f *LogoutFeature) Identifier() security.FeatureIdentifier {
	return FeatureId
}

// LogoutHandlers replaces the whole list of logout handlers, including the default one installed by New.
func (f *LogoutFeature) LogoutHandlers(logoutHandlers ...LogoutHandler) *LogoutFeature {
	f.logoutHandlers = logoutHandlers
	return f
}

// AddLogoutHandler puts logoutHandler in front of the currently configured logout handlers.
func (f *LogoutFeature) AddLogoutHandler(logoutHandler LogoutHandler) *LogoutFeature {
	prepended := make([]LogoutHandler, 1, len(f.logoutHandlers)+1)
	prepended[0] = logoutHandler
	f.logoutHandlers = append(prepended, f.logoutHandlers...)
	return f
}

// LogoutUrl sets the path that triggers logout.
func (f *LogoutFeature) LogoutUrl(logoutUrl string) *LogoutFeature {
	f.logoutUrl = logoutUrl
	return f
}

// SuccessUrl sets the redirect target used when no success handler has been added.
func (f *LogoutFeature) SuccessUrl(successUrl string) *LogoutFeature {
	f.successUrl = successUrl
	return f
}

// ErrorUrl sets the redirect target used when no error handler has been added.
func (f *LogoutFeature) ErrorUrl(errorUrl string) *LogoutFeature {
	f.errorUrl = errorUrl
	return f
}

// AddSuccessHandler adds a success handler; once any is added, SuccessUrl is no longer used.
func (f *LogoutFeature) AddSuccessHandler(successHandler security.AuthenticationSuccessHandler) *LogoutFeature {
	f.successHandlers = append(f.successHandlers, successHandler)
	return f
}

// AddErrorHandler adds an error handler; once any is added, ErrorUrl is no longer used.
func (f *LogoutFeature) AddErrorHandler(errorHandler security.AuthenticationErrorHandler) *LogoutFeature {
	f.errorHandlers = append(f.errorHandlers, errorHandler)
	return f
}

// AddEntryPoint adds an entry point used when a ConditionalLogoutHandler cancels or delays logout.
func (f *LogoutFeature) AddEntryPoint(entryPoint security.AuthenticationEntryPoint) *LogoutFeature {
	f.entryPoints = append(f.entryPoints, entryPoint)
	return f
}
/*********************************
Constructors and Configure
*********************************/
// Configure is the standard security.Feature entrypoint, used to modify the logout configuration of the
// given security.WebSecurity. It enables a new LogoutFeature, or returns the one already enabled.
// It panics if ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *LogoutFeature {
	feature := New()
	if fc, ok := ws.(security.FeatureModifier); ok {
		return fc.Enable(feature).(*LogoutFeature)
	}
	// the message previously said "form login" (copy-paste); this is the logout feature
	panic(fmt.Errorf("unable to configure logout: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
}
// New is the standard security.Feature entrypoint, DSL style, used with security.WebSecurity.
// Defaults: logout URL "/logout", success URL "/login", and a single DefaultLogoutHandler.
func New() *LogoutFeature {
	f := &LogoutFeature{}
	f.logoutUrl = "/logout"
	f.successUrl = "/login"
	f.logoutHandlers = []LogoutHandler{DefaultLogoutHandler{}}
	return f
}
// DefaultLogoutHandler clears the security context of the request.
type DefaultLogoutHandler struct{}

// HandleLogout calls security.MustClear on ctx and returns nil.
func (DefaultLogoutHandler) HandleLogout(ctx context.Context, _ *http.Request, _ http.ResponseWriter, _ security.Authentication) error {
	security.MustClear(ctx)
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package logout
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/gin-gonic/gin"
)
// ctxKeyWarnings is the gin context key under which logout warnings are stored.
// A string key is needed because values are stored with gin.Context.Set.
var ctxKeyWarnings = "logout.Warnings"

// GetWarnings returns the warnings recorded during logout, or nil when there are none.
func GetWarnings(ctx context.Context) Warnings {
	if warnings, ok := ctx.Value(ctxKeyWarnings).(Warnings); ok {
		return warnings
	}
	return nil
}
// LogoutMiddleware runs conditional checks and logout handlers, then delegates to success/error
// handlers or the entry point.
//
//goland:noinspection GoNameStartsWithPackageName
type LogoutMiddleware struct {
	successHandler      security.AuthenticationSuccessHandler
	errorHandler        security.AuthenticationErrorHandler
	entryPoint          security.AuthenticationEntryPoint
	logoutHandlers      []LogoutHandler
	conditionalHandlers []ConditionalLogoutHandler
}

// NewLogoutMiddleware creates a LogoutMiddleware. Any logout handler that also implements
// ConditionalLogoutHandler is additionally used for the pre-logout check.
func NewLogoutMiddleware(successHandler security.AuthenticationSuccessHandler,
	errorHandler security.AuthenticationErrorHandler,
	entryPoint security.AuthenticationEntryPoint,
	logoutHandlers ...LogoutHandler) *LogoutMiddleware {
	mw := &LogoutMiddleware{
		successHandler:      successHandler,
		errorHandler:        errorHandler,
		entryPoint:          entryPoint,
		logoutHandlers:      logoutHandlers,
		conditionalHandlers: make([]ConditionalLogoutHandler, 0, len(logoutHandlers)),
	}
	for _, h := range logoutHandlers {
		if cond, ok := h.(ConditionalLogoutHandler); ok {
			mw.conditionalHandlers = append(mw.conditionalHandlers, cond)
		}
	}
	return mw
}
// LogoutHandlerFunc returns the gin handler performing logout for the current request.
func (mw *LogoutMiddleware) LogoutHandlerFunc() gin.HandlerFunc {
	return func(gc *gin.Context) {
		before := security.Get(gc)

		// pre-logout check: the first objection stops the whole process
		for _, cond := range mw.conditionalHandlers {
			err := cond.ShouldLogout(gc, gc.Request, gc.Writer, before)
			if err == nil {
				continue
			}
			if mw.entryPoint != nil {
				mw.handleCancelled(gc, err)
			} else {
				mw.handleError(gc, err)
			}
			return
		}

		// perform logout: warnings are collected, any other error stops processing
		for _, h := range mw.logoutHandlers {
			err := h.HandleLogout(gc, gc.Request, gc.Writer, before)
			if errors.Is(err, security.ErrorSubTypeAuthWarning) {
				mw.handleWarnings(gc, err)
				continue
			}
			if err != nil {
				mw.handleError(gc, err)
				return
			}
		}
		mw.handleSuccess(gc, before)
	}
}
// handleSuccess invokes the success handler with the authentication before and after logout,
// and aborts the chain if a response was written.
func (mw *LogoutMiddleware) handleSuccess(gc *gin.Context, before security.Authentication) {
	after := security.Get(gc)
	mw.successHandler.HandleAuthenticationSuccess(gc, gc.Request, gc.Writer, before, after)
	if !gc.Writer.Written() {
		return
	}
	gc.Abort()
}
// handleWarnings appends warning to the Warnings stored under ctxKeyWarnings in the gin context.
// A plain []error is converted; any other stored value is kept as a formatted error.
func (mw *LogoutMiddleware) handleWarnings(gc *gin.Context, warning error) {
	var collected Warnings
	switch prev := gc.Value(ctxKeyWarnings).(type) {
	case nil:
		collected = Warnings{warning}
	case Warnings:
		collected = append(prev, warning)
	case []error:
		collected = append(Warnings(prev), warning)
	default:
		collected = Warnings{fmt.Errorf("%v", prev), warning}
	}
	gc.Set(ctxKeyWarnings, collected)
}
// handleError passes err to the error handler, first wrapping it as an internal authentication error
// unless it already is a security error, and aborts the chain if a response was written.
func (mw *LogoutMiddleware) handleError(gc *gin.Context, err error) {
	authErr := err
	if !errors.Is(err, security.ErrorTypeSecurity) {
		authErr = security.NewInternalAuthenticationError(err.Error(), err)
	}
	mw.errorHandler.HandleAuthenticationError(gc, gc.Request, gc.Writer, authErr)
	if gc.Writer.Written() {
		gc.Abort()
	}
}
// handleCancelled commences the configured entry point when a ConditionalLogoutHandler objected,
// and aborts the chain if a response was written.
func (mw *LogoutMiddleware) handleCancelled(gc *gin.Context, err error) {
	mw.entryPoint.Commence(gc, gc.Request, gc.Writer, err)
	if !gc.Writer.Written() {
		return
	}
	gc.Abort()
}
// Code generated by mockery v2.14.0. DO NOT EDIT.
package mocks
import (
context "context"
http "net/http"
mock "github.com/stretchr/testify/mock"
security "github.com/cisco-open/go-lanai/pkg/security"
)
// ConditionalLogoutHandler is an autogenerated mock type for the ConditionalLogoutHandler type
// NOTE(review): generated by mockery v2.14.0; regenerate rather than editing by hand.
type ConditionalLogoutHandler struct {
mock.Mock
}
// ShouldLogout provides a mock function with given fields: _a0, _a1, _a2, _a3
// The call is recorded via mock.Called. If the registered return value is a function with the
// ShouldLogout signature, it is invoked with the call arguments; otherwise the first registered
// return value is returned as the error.
func (_m *ConditionalLogoutHandler) ShouldLogout(_a0 context.Context, _a1 *http.Request, _a2 http.ResponseWriter, _a3 security.Authentication) error {
ret := _m.Called(_a0, _a1, _a2, _a3)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *http.Request, http.ResponseWriter, security.Authentication) error); ok {
r0 = rf(_a0, _a1, _a2, _a3)
} else {
r0 = ret.Error(0)
}
return r0
}
// mockConstructorTestingTNewConditionalLogoutHandler is the minimal testing interface accepted by
// NewConditionalLogoutHandler (e.g. *testing.T).
type mockConstructorTestingTNewConditionalLogoutHandler interface {
mock.TestingT
Cleanup(func())
}
// NewConditionalLogoutHandler creates a new instance of ConditionalLogoutHandler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewConditionalLogoutHandler(t mockConstructorTestingTNewConditionalLogoutHandler) *ConditionalLogoutHandler {
mock := &ConditionalLogoutHandler{}
mock.Mock.Test(t)
// expectations are asserted automatically when the test finishes
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// Code generated by mockery v2.14.0. DO NOT EDIT.
package mocks
import (
context "context"
http "net/http"
mock "github.com/stretchr/testify/mock"
security "github.com/cisco-open/go-lanai/pkg/security"
)
// LogoutHandler is an autogenerated mock type for the LogoutHandler type
// NOTE(review): generated by mockery v2.14.0; regenerate rather than editing by hand.
type LogoutHandler struct {
mock.Mock
}
// HandleLogout provides a mock function with given fields: _a0, _a1, _a2, _a3
// The call is recorded via mock.Called. If the registered return value is a function with the
// HandleLogout signature, it is invoked with the call arguments; otherwise the first registered
// return value is returned as the error.
func (_m *LogoutHandler) HandleLogout(_a0 context.Context, _a1 *http.Request, _a2 http.ResponseWriter, _a3 security.Authentication) error {
ret := _m.Called(_a0, _a1, _a2, _a3)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *http.Request, http.ResponseWriter, security.Authentication) error); ok {
r0 = rf(_a0, _a1, _a2, _a3)
} else {
r0 = ret.Error(0)
}
return r0
}
// mockConstructorTestingTNewLogoutHandler is the minimal testing interface accepted by
// NewLogoutHandler (e.g. *testing.T).
type mockConstructorTestingTNewLogoutHandler interface {
mock.TestingT
Cleanup(func())
}
// NewLogoutHandler creates a new instance of LogoutHandler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewLogoutHandler(t mockConstructorTestingTNewLogoutHandler) *LogoutHandler {
mock := &LogoutHandler{}
mock.Mock.Test(t)
// expectations are asserted automatically when the test finishes
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package logout
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"go.uber.org/fx"
)
// Module is the bootstrap module that registers the logout feature's configurer.
//
//goland:noinspection GoNameStartsWithPackageName
var Module = &bootstrap.Module{
	Name:       "logout",
	Precedence: security.MinSecurityPrecedence + 20,
	Options:    []fx.Option{fx.Invoke(register)},
}

func init() {
	bootstrap.Register(Module)
}

// initDI holds the optional dependencies of register.
type initDI struct {
	fx.In
	SecRegistrar security.Registrar `optional:"true"`
}

// register installs the LogoutConfigurer when a security.Registrar is available.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	configurer := newLogoutConfigurer()
	di.SecRegistrar.(security.FeatureRegistrar).RegisterFeature(FeatureId, configurer)
}
package auth
import (
"context"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// Approval records the scopes a user approved for a client and redirect URI.
type Approval struct {
	UserId      interface{}
	Username    string
	ClientId    string
	RedirectUri string
	Scopes      utils.StringSet
}

// ApprovalLoadOptions fills in the criteria passed to ApprovalStore.LoadApprovals.
type ApprovalLoadOptions func(*Approval)

// ApprovalStore persists and loads user approvals.
type ApprovalStore interface {
	SaveApproval(c context.Context, a *Approval) error
	LoadApprovals(c context.Context, opts ...ApprovalLoadOptions) ([]*Approval, error)
}

// WithUserId sets Approval.UserId.
func WithUserId(userId interface{}) ApprovalLoadOptions {
	setUserId := func(a *Approval) { a.UserId = userId }
	return setUserId
}

// WithUsername sets Approval.Username.
func WithUsername(username string) ApprovalLoadOptions {
	setUsername := func(a *Approval) { a.Username = username }
	return setUsername
}

// WithClientId sets Approval.ClientId.
func WithClientId(clientId string) ApprovalLoadOptions {
	setClientId := func(a *Approval) { a.ClientId = clientId }
	return setClientId
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package authorize
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/web/mapping"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
"net/http"
)
// FeatureId identifies the OAuth2 authorize endpoint feature.
var FeatureId = security.FeatureId("OAuth2AuthorizeEndpoint", security.FeatureOrderOAuth2AuthorizeEndpoint)
// AuthorizeEndpointConfigurer is the security.FeatureConfigurer for AuthorizeFeature. It holds no state.
//
//goland:noinspection GoNameStartsWithPackageName
type AuthorizeEndpointConfigurer struct{}

func newOAuth2AuthorizeEndpointConfigurer() *AuthorizeEndpointConfigurer {
	return new(AuthorizeEndpointConfigurer)
}
// Apply installs the OAuth2 authorize feature onto ws:
//   - an error handler on the errorhandling feature;
//   - a validation middleware covering the authorize and approval routes;
//   - GET/POST authorize endpoints;
//   - a POST approval endpoint.
// The named result "err" was dropped: it was shadowed by the validate check and never used.
func (c *AuthorizeEndpointConfigurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	f := feature.(*AuthorizeFeature)
	if err := c.validate(f, ws); err != nil {
		return err
	}

	// route OAuth2 errors through the feature's error handler (defaulted by validate)
	errorhandling.Configure(ws).
		AdditionalErrorHandler(f.errorHandler)

	// prepare matchers and the shared middleware instance
	authRouteMatcher := matcher.RouteWithPattern(f.path, http.MethodGet, http.MethodPost)
	approveRouteMatcher := matcher.RouteWithPattern(f.approvalPath, http.MethodPost)
	approveRequestMatcher := matcher.RequestWithPattern(f.approvalPath, http.MethodPost).
		And(matcher.RequestHasPostForm(oauth2.ParameterUserApproval))
	authorizeMW := NewAuthorizeEndpointMiddleware(func(opts *AuthorizeMWOption) {
		opts.RequestProcessor = f.requestProcessor
		opts.AuthorizeHandler = f.authorizeHandler
		opts.ApprovalMatcher = approveRequestMatcher
		opts.ApprovalStore = f.approvalStore
	})

	// validation middleware for both the authorize and the approval routes
	preAuth := middleware.NewBuilder("authorize validation").
		ApplyTo(authRouteMatcher.Or(approveRouteMatcher)).
		Order(security.MWOrderOAuth2AuthValidation).
		Use(authorizeMW.PreAuthenticateHandlerFunc(f.condition))
	ws.Add(preAuth)

	// authorize endpoint
	epGet := mapping.Get(f.path).Name("authorize GET").
		HandlerFunc(authorizeMW.AuthorizeHandlerFunc(f.condition))
	epPost := mapping.Post(f.path).Name("authorize Post").
		HandlerFunc(authorizeMW.AuthorizeHandlerFunc(f.condition))
	ws.Route(authRouteMatcher).Add(epGet, epPost)

	// approve endpoint
	approve := mapping.Post(f.approvalPath).Name("approve endpoint").
		HandlerFunc(authorizeMW.ApproveOrDenyHandlerFunc())
	ws.Route(approveRouteMatcher).Add(approve)
	return nil
}
// validate checks the mandatory settings of f and fills in defaults for optional ones.
//
// Required: path and authorizeHandler.
// Defaults:
//   - errorHandler: auth.NewOAuth2ErrorHandler()
//   - requestProcessor: auth.NewAuthorizeRequestProcessor(), an empty processor chain.
//     Apply always copies f.requestProcessor into the middleware options, so leaving it
//     nil would override the middleware's own default and panic on the first request.
func (c *AuthorizeEndpointConfigurer) validate(f *AuthorizeFeature, ws security.WebSecurity) error {
	if f.path == "" {
		return fmt.Errorf("authorize endpoint path is not set")
	}
	if f.errorHandler == nil {
		f.errorHandler = auth.NewOAuth2ErrorHandler()
	}
	if f.requestProcessor == nil {
		f.requestProcessor = auth.NewAuthorizeRequestProcessor()
	}
	if f.authorizeHandler == nil {
		return fmt.Errorf("authorize handler is not set")
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package authorize
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/web"
)
// AuthorizeFeature configures authorization endpoints: the OAuth2 authorize endpoint
// and the user approval endpoint. Configure it DSL-style via Configure or NewEndpoint.
//
//goland:noinspection GoNameStartsWithPackageName
type AuthorizeFeature struct {
path string // authorize endpoint path (GET and POST); required
condition web.RequestMatcher // requests not matching it pass through the endpoint handlers untouched
approvalPath string // approval endpoint path (POST)
requestProcessor auth.AuthorizeRequestProcessor // validates and processes received authorize requests
authorizeHandler auth.AuthorizeHandler // produces approval page and approved responses; required
errorHandler *auth.OAuth2ErrorHandler // defaults to auth.NewOAuth2ErrorHandler() during validation
approvalStore auth.ApprovalStore // optional; when nil, user approvals are neither saved nor looked up
}
// Identifier implements security.Feature; it always reports FeatureId.
func (f *AuthorizeFeature) Identifier() security.FeatureIdentifier {
	id := FeatureId
	return id
}
// Configure is the standard security.Feature entrypoint. It enables a new AuthorizeFeature
// on ws and returns the enabled instance for further configuration.
// It panics when ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *AuthorizeFeature {
	modifier, ok := ws.(security.FeatureModifier)
	if !ok {
		panic(fmt.Errorf("unable to configure oauth2 authserver: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
	}
	return modifier.Enable(NewEndpoint()).(*AuthorizeFeature)
}
// NewEndpoint is the standard DSL-style security.Feature entrypoint, used with
// security.WebSecurity. The returned feature has no settings applied.
func NewEndpoint() *AuthorizeFeature {
	return new(AuthorizeFeature)
}
/** Setters **/
// Path sets the authorize endpoint path, served for GET and POST. Required.
func (f *AuthorizeFeature) Path(path string) *AuthorizeFeature {
	f.path = path
	return f
}

// Condition sets a request matcher gating the endpoint handlers. Requests that do not
// match it are passed through without being processed.
func (f *AuthorizeFeature) Condition(condition web.RequestMatcher) *AuthorizeFeature {
	f.condition = condition
	return f
}

// ApprovalPath sets the path of the user approval endpoint (POST).
func (f *AuthorizeFeature) ApprovalPath(approvalPath string) *AuthorizeFeature {
	f.approvalPath = approvalPath
	return f
}

// RequestProcessors sets the request processor to a chain of the given processors,
// invoked in order. It replaces any processor previously set with RequestProcessor.
func (f *AuthorizeFeature) RequestProcessors(processors ...auth.ChainedAuthorizeRequestProcessor) *AuthorizeFeature {
	f.requestProcessor = auth.NewAuthorizeRequestProcessor(processors...)
	return f
}

// RequestProcessor sets the request processor directly. It replaces any chain
// previously set with RequestProcessors.
func (f *AuthorizeFeature) RequestProcessor(processor auth.AuthorizeRequestProcessor) *AuthorizeFeature {
	f.requestProcessor = processor
	return f
}

// ErrorHandler sets the OAuth2 error handler installed by this feature.
func (f *AuthorizeFeature) ErrorHandler(errorHandler *auth.OAuth2ErrorHandler) *AuthorizeFeature {
	f.errorHandler = errorHandler
	return f
}

// AuthorizeHandler sets the handler producing the approval page and approved responses. Required.
func (f *AuthorizeFeature) AuthorizeHandler(authHandler auth.AuthorizeHandler) *AuthorizeFeature {
	f.authorizeHandler = authHandler
	return f
}

// AuthorizeHanlder is the original, misspelled name of AuthorizeHandler, kept so
// existing callers keep compiling.
//
// Deprecated: use AuthorizeHandler instead.
func (f *AuthorizeFeature) AuthorizeHanlder(authHanlder auth.AuthorizeHandler) *AuthorizeFeature {
	return f.AuthorizeHandler(authHanlder)
}

// ApprovalStore sets the optional store used to persist user approvals and to skip
// the approval page for previously approved requests.
func (f *AuthorizeFeature) ApprovalStore(store auth.ApprovalStore) *AuthorizeFeature {
	f.approvalStore = store
	return f
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package authorize
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/gin-gonic/gin"
"strconv"
"strings"
)
const (
// sessionKeyAuthorizeRequest is the session attribute holding the AuthorizeRequest
// between rendering the approval page and receiving the user's approval.
sessionKeyAuthorizeRequest = "kAuthorizeRequest"
// scopeParamPrefix prefixes the per-scope approval form fields, e.g. "scope.<name>".
scopeParamPrefix = "scope."
)
/***********************
Authorize Endpoint
***********************/
// AuthorizeEndpointMiddleware provides the gin handlers of the authorize and approval
// endpoints. Build it with NewAuthorizeEndpointMiddleware.
//
//goland:noinspection GoNameStartsWithPackageName
type AuthorizeEndpointMiddleware struct {
requestProcessor auth.AuthorizeRequestProcessor // validates/processes received requests
authorizeHandler auth.AuthorizeHandler // produces approval page and approved responses
approveMatcher web.RequestMatcher // recognises approval form submissions
approvalStore auth.ApprovalStore // optional; nil disables saving and looking up approvals
}
// AuthorizeMWOptions is a functional option for NewAuthorizeEndpointMiddleware.
//
//goland:noinspection GoNameStartsWithPackageName
type AuthorizeMWOptions func(*AuthorizeMWOption)
// AuthorizeMWOption holds the settings consumed by NewAuthorizeEndpointMiddleware.
//
//goland:noinspection GoNameStartsWithPackageName
type AuthorizeMWOption struct {
RequestProcessor auth.AuthorizeRequestProcessor // defaults to an empty processor chain
AuthorizeHandler auth.AuthorizeHandler
ApprovalMatcher web.RequestMatcher // matching requests load the saved request from session instead of parsing it
ApprovalStore auth.ApprovalStore
}
// NewAuthorizeEndpointMiddleware builds the middleware from functional options.
// RequestProcessor starts as auth.NewAuthorizeRequestProcessor() (an empty chain);
// nil option functions are skipped.
func NewAuthorizeEndpointMiddleware(opts ...AuthorizeMWOptions) *AuthorizeEndpointMiddleware {
	cfg := AuthorizeMWOption{RequestProcessor: auth.NewAuthorizeRequestProcessor()}
	for _, apply := range opts {
		if apply == nil {
			continue
		}
		apply(&cfg)
	}
	mw := AuthorizeEndpointMiddleware{
		requestProcessor: cfg.RequestProcessor,
		authorizeHandler: cfg.AuthorizeHandler,
		approveMatcher:   cfg.ApprovalMatcher,
		approvalStore:    cfg.ApprovalStore,
	}
	return &mw
}
// PreAuthenticateHandlerFunc returns the validation handler that runs before
// authentication on the authorize and approval routes.
//
// When condition matches (a nil condition matches everything), the handler obtains the
// AuthorizeRequest. Approval submissions (matched by the approval matcher) load it from
// the session; all other requests parse it from the HTTP request. The handler stores it
// under oauth2.CtxKeyReceivedAuthorizeRequest and runs it through the request processor.
// The processed request goes under oauth2.CtxKeyValidatedAuthorizeRequest. Context values
// collected during processing are copied into ctx even when processing fails. Any failure
// aborts the chain via handleError, and the original load/parse error is kept as the cause.
func (mw *AuthorizeEndpointMiddleware) PreAuthenticateHandlerFunc(condition web.RequestMatcher) gin.HandlerFunc {
	return func(ctx *gin.Context) {
		if condition != nil {
			if matches, err := condition.MatchesWithContext(ctx, ctx.Request); !matches || err != nil {
				return
			}
		}

		// a matcher error is treated as "not an approval submission"
		isApproval := false
		if mw.approveMatcher != nil {
			approve, e := mw.approveMatcher.MatchesWithContext(ctx, ctx.Request)
			isApproval = e == nil && approve
		}

		// parse or load request
		var request *auth.AuthorizeRequest
		var err error
		if isApproval {
			// the original request was saved in session when the approval page was rendered
			if request, err = mw.loadAuthorizeRequest(ctx); err != nil {
				err = oauth2.NewInvalidAuthorizeRequestError("error loading authorize request for approval", err)
			}
		} else {
			if request, err = auth.ParseAuthorizeRequest(ctx.Request); err != nil {
				err = oauth2.NewInvalidAuthorizeRequestError("invalid authorize request", err)
			}
		}
		if err != nil {
			mw.handleError(ctx, err)
			return
		}
		ctx.Set(oauth2.CtxKeyReceivedAuthorizeRequest, request)

		// validate and process; regardless of the result, context values gathered by the
		// processors are transferred to the current context
		processed, e := mw.requestProcessor.Process(ctx, request)
		if e != nil {
			mw.transferContextValues(request.Context(), ctx)
			mw.handleError(ctx, e)
			return
		}

		// everything is ok, set it to context for later usage
		mw.transferContextValues(processed.Context(), ctx)
		ctx.Set(oauth2.CtxKeyValidatedAuthorizeRequest, processed)
	}
}
// AuthorizeHandlerFunc returns the handler of the authorize endpoint (GET and POST).
//
// It requires a validated request (stored by PreAuthenticateHandlerFunc), an
// authenticated user and an authenticated client. If auth.ValidateAllAutoApprovalScopes
// accepts the requested scopes, or the user has a saved approval with the same redirect
// URI and scopes, the approved response is produced directly. Otherwise the request is
// saved in session and the approval page is rendered. A nil condition matches every
// request. Requests not matching condition are left untouched.
func (mw *AuthorizeEndpointMiddleware) AuthorizeHandlerFunc(condition web.RequestMatcher) gin.HandlerFunc {
	return func(ctx *gin.Context) {
		if condition != nil {
			if matches, err := condition.MatchesWithContext(ctx, ctx.Request); !matches || err != nil {
				return
			}
		}

		// sanity checks
		request, client, user, e := mw.endpointSanityCheck(ctx)
		if e != nil {
			mw.handleError(ctx, e)
			return
		}
		logger.WithContext(ctx).Debug(fmt.Sprintf("AuthorizeRequest: %s", request))

		// saved approvals are only consulted when auto-approval doesn't cover the request
		needsApproval := auth.ValidateAllAutoApprovalScopes(ctx, client, request.Scopes) != nil &&
			!mw.hasSavedApproval(ctx, user, request)

		var respFunc auth.ResponseHandlerFunc
		if needsApproval {
			// keep the request in session so the approval submission can resume it
			if saveErr := mw.saveAuthorizeRequest(ctx, request); saveErr != nil {
				mw.handleError(ctx, saveErr)
				return
			}
			respFunc, e = mw.authorizeHandler.HandleApprovalPage(ctx, request, user)
		} else {
			respFunc, e = mw.authorizeHandler.HandleApproved(ctx, request, user)
		}
		if e != nil {
			mw.handleError(ctx, e)
			return
		}
		mw.handleSuccess(ctx, respFunc)
	}
}
// ApproveOrDenyHandlerFunc returns the handler of the approval endpoint.
//
// It reads the user's decision from the POST form (parseApproval) and validates it with
// auth.ValidateApproval against the requested scopes. On success it marks the request
// approved, persists the approval on a best-effort basis, and writes the approved
// response. The authorize request saved in session is always cleared when the handler
// returns.
func (mw *AuthorizeEndpointMiddleware) ApproveOrDenyHandlerFunc() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		// no matter what happens, this is the last step, so clear the saved request when done
		defer func() { _ = mw.clearAuthorizeRequest(ctx) }()

		// sanity checks
		request, client, user, e := mw.endpointSanityCheck(ctx)
		if e != nil {
			mw.handleError(ctx, e)
			return
		}
		logger.WithContext(ctx).Debug(fmt.Sprintf("AuthorizeRequest: %s", request))

		// parse approval params and check
		approval := mw.parseApproval(ctx)
		if e := auth.ValidateApproval(ctx, approval, client, request.Scopes); e != nil {
			mw.handleError(ctx, e)
			return
		}

		request.Approved = true
		if e := mw.saveApprovedRequest(ctx, user, request); e != nil {
			// best-effort: the current request is still granted, but without a saved
			// approval the user may be asked to approve again next time
			logger.WithContext(ctx).Warnf("failed to save user approval: %v", e)
		}

		// write response
		respFunc, e := mw.authorizeHandler.HandleApproved(ctx, request, user)
		if e != nil {
			mw.handleError(ctx, e)
			return
		}
		mw.handleSuccess(ctx, respFunc)
	}
}
// handleSuccess writes the successful outcome and aborts the remaining handler chain.
// An auth.ResponseHandlerFunc writes the response itself. Any other value is rendered
// as a 200 JSON body.
func (mw *AuthorizeEndpointMiddleware) handleSuccess(c *gin.Context, v interface{}) {
	switch fn := v.(type) {
	case auth.ResponseHandlerFunc:
		fn(c)
	default:
		c.JSON(200, v)
	}
	c.Abort()
}
// handleError wraps err as an invalid-authorize-request error unless it already is an
// OAuth2 error. It then discards any authorize request saved in session, records the
// error on the gin context (c.Error), and aborts the handler chain.
func (mw *AuthorizeEndpointMiddleware) handleError(c *gin.Context, err error) {
	oauthErr := err
	if !errors.Is(oauthErr, oauth2.ErrorTypeOAuth2) {
		oauthErr = oauth2.NewInvalidAuthorizeRequestError(err)
	}
	_ = mw.clearAuthorizeRequest(c)
	_ = c.Error(oauthErr)
	c.Abort()
}
// saveAuthorizeRequest stores request in the current session under
// sessionKeyAuthorizeRequest and saves the session. It fails with an internal error
// when there is no session or saving fails.
func (mw *AuthorizeEndpointMiddleware) saveAuthorizeRequest(ctx *gin.Context, request *auth.AuthorizeRequest) error {
	sess := session.Get(ctx)
	if sess == nil {
		return oauth2.NewInternalError("failed to save authorize request for approval")
	}
	sess.Set(sessionKeyAuthorizeRequest, request)
	saveErr := sess.Save()
	if saveErr == nil {
		return nil
	}
	return oauth2.NewInternalError("failed to save authorize request for approval", saveErr)
}
// loadAuthorizeRequest retrieves the AuthorizeRequest saved in session by
// saveAuthorizeRequest. Its context is reset to a fresh background-based context.
// This modifies the stored request object in place. It fails with an internal error
// when there is no session or no request is stored.
func (mw *AuthorizeEndpointMiddleware) loadAuthorizeRequest(ctx *gin.Context) (*auth.AuthorizeRequest, error) {
	sess := session.Get(ctx)
	if sess == nil {
		return nil, oauth2.NewInternalError("failed to load authorize request for approval (no session)")
	}
	saved, found := sess.Get(sessionKeyAuthorizeRequest).(*auth.AuthorizeRequest)
	if !found {
		return nil, oauth2.NewInternalError("failed to load authorize request for approval")
	}
	return saved.WithContext(context.Background()), nil
}
// clearAuthorizeRequest removes the saved AuthorizeRequest from the session and saves it.
// It fails with an internal error when there is no session or saving fails.
func (mw *AuthorizeEndpointMiddleware) clearAuthorizeRequest(ctx *gin.Context) error {
	sess := session.Get(ctx)
	if sess == nil {
		return oauth2.NewInternalError("failed to clear authorize request for approval (no session)")
	}
	sess.Delete(sessionKeyAuthorizeRequest)
	saveErr := sess.Save()
	if saveErr == nil {
		return nil
	}
	return oauth2.NewInternalError("failed to clear authorize request for approval", saveErr)
}
// saveApprovedRequest persists the user's approval of r (client, redirect URI, scopes)
// in the approval store. It is a no-op without a store.
//
// If the principal is a security.Account, the approval records both its username and
// ID. Otherwise it records only the username from security.GetUsername. It fails when
// r is not approved, when no username can be determined, or when the store fails.
func (mw *AuthorizeEndpointMiddleware) saveApprovedRequest(ctx *gin.Context, u security.Authentication, r *auth.AuthorizeRequest) error {
	switch {
	case mw.approvalStore == nil:
		// no approval store is provided, nothing to save
		return nil
	case !r.Approved:
		return oauth2.NewInternalError("attempting to save unapproved request")
	}

	approval := &auth.Approval{
		ClientId:    r.ClientId,
		RedirectUri: r.RedirectUri,
		Scopes:      r.Scopes,
	}
	if account, isAccount := u.Principal().(security.Account); isAccount {
		approval.Username = account.Username()
		approval.UserId = account.ID()
	} else if username, e := security.GetUsername(u); e == nil {
		approval.Username = username
	} else {
		return oauth2.NewInternalError("can't save approval without user id or username")
	}

	if e := mw.approvalStore.SaveApproval(ctx, approval); e != nil {
		return oauth2.NewInternalError("failed to save approved request", e)
	}
	return nil
}
// endpointSanityCheck collects what both endpoint handlers need: the validated request
// stored by PreAuthenticateHandlerFunc, the current authentication (at least
// StateAuthenticated), and the authenticated client. If any is missing, all three
// results are nil and an internal error is returned.
func (mw *AuthorizeEndpointMiddleware) endpointSanityCheck(ctx *gin.Context) (
	*auth.AuthorizeRequest, oauth2.OAuth2Client, security.Authentication, error) {
	validated, isValidated := ctx.Value(oauth2.CtxKeyValidatedAuthorizeRequest).(*auth.AuthorizeRequest)
	if !isValidated {
		return nil, nil, nil, oauth2.NewInternalError("authorize request not processed")
	}

	currentUser := security.Get(ctx)
	if currentUser.State() < security.StateAuthenticated {
		return nil, nil, nil, oauth2.NewInternalError("authorize endpoint is called without user authentication")
	}

	// the client is expected to be populated by the pre-auth middleware
	authClient := auth.RetrieveAuthenticatedClient(ctx)
	if authClient == nil {
		return nil, nil, nil, oauth2.NewInternalError("client is not loaded")
	}
	return validated, authClient, currentUser, nil
}
// parseApproval reads the user's decision from the POST form.
//
// Unless the last value of the oauth2.ParameterUserApproval field parses as true, it
// returns an empty map. Otherwise every "scope.<name>" field becomes an entry keyed by
// <name>. Its value is the boolean parsed from the field's last value, or false when the
// field has no value or does not parse. A field present with no values is handled the
// same way for the approval flag; the original indexed v[len(v)-1] unguarded and panicked.
func (mw *AuthorizeEndpointMiddleware) parseApproval(ctx *gin.Context) (approval map[string]bool) {
	// lastBool parses the last submitted value; parse errors deliberately mean "false"
	lastBool := func(values []string) bool {
		if len(values) == 0 {
			return false
		}
		b, _ := strconv.ParseBool(values[len(values)-1])
		return b
	}

	approval = make(map[string]bool)
	form := ctx.Request.PostForm
	if !lastBool(form[oauth2.ParameterUserApproval]) {
		return approval
	}
	for k, v := range form {
		if !strings.HasPrefix(k, scopeParamPrefix) {
			continue
		}
		approval[strings.TrimPrefix(k, scopeParamPrefix)] = lastBool(v)
	}
	return approval
}
// transferContextValues copies every value listed by src (a utils.ListableContext)
// into the mutable context found in dst. It does nothing if src is not listable or
// dst has no mutable context.
func (mw *AuthorizeEndpointMiddleware) transferContextValues(src context.Context, dst context.Context) {
	target := utils.FindMutableContext(dst)
	source, isListable := src.(utils.ListableContext)
	if target == nil || !isListable {
		return
	}
	for key, value := range source.Values() {
		target.Set(key, value)
	}
}
// hasSavedApproval reports whether the approval store holds an approval by user for
// the request's client with the same redirect URI and exactly the same scopes.
// Missing store, unknown username and store errors all yield false.
func (mw *AuthorizeEndpointMiddleware) hasSavedApproval(ctx *gin.Context, user security.Authentication, request *auth.AuthorizeRequest) bool {
	if mw.approvalStore == nil {
		return false
	}

	loadOpts := []auth.ApprovalLoadOptions{auth.WithClientId(request.ClientId)}
	if account, isAccount := user.Principal().(security.Account); isAccount {
		loadOpts = append(loadOpts, auth.WithUsername(account.Username()), auth.WithUserId(account.ID()))
	} else if username, e := security.GetUsername(user); e == nil {
		loadOpts = append(loadOpts, auth.WithUsername(username))
	} else {
		return false
	}

	approvals, e := mw.approvalStore.LoadApprovals(ctx, loadOpts...)
	if e != nil {
		return false
	}
	for _, saved := range approvals {
		if request.RedirectUri != saved.RedirectUri {
			continue
		}
		if request.Scopes.Equals(saved.Scopes) {
			return true
		}
	}
	return false
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package authorize
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security"
"go.uber.org/fx"
)
// logger is the package logger, named "OAuth2.Auth".
var logger = log.New("OAuth2.Auth")
// Module is the bootstrap module of the OAuth2 authorize endpoint. It invokes register,
// which registers the authorize endpoint feature configurer.
//
//goland:noinspection GoNameStartsWithPackageName
var Module = &bootstrap.Module{
Name: "oauth2 auth - authorize",
Precedence: security.MinSecurityPrecedence + 20,
Options: []fx.Option{
fx.Invoke(register),
},
}
// init registers Module with bootstrap as soon as this package is imported.
func init() {
bootstrap.Register(Module)
}
// initDI holds register's fx dependencies. SecRegistrar is optional, so the feature
// is registered only when a security.Registrar is provided.
type initDI struct {
fx.In
SecRegistrar security.Registrar `optional:"true"`
}
// register adds the authorize endpoint configurer to the security feature registry
// under FeatureId. It does nothing without a security.Registrar and panics if the
// registrar does not implement security.FeatureRegistrar.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	configurer := newOAuth2AuthorizeEndpointConfigurer()
	registrar := di.SecRegistrar.(security.FeatureRegistrar)
	registrar.RegisterFeature(FeatureId, configurer)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
const (
// defaultAuthCodeLength is the number of characters in a generated authorization code.
defaultAuthCodeLength = 32
// authCodePrefix namespaces authorization code keys in Redis: "<authCodePrefix>:<code>".
authCodePrefix = "AC"
)
var (
// authCodeValidity is the Redis TTL applied to a stored authorization code.
authCodeValidity = 5 * time.Minute
)
/**********************
Abstraction
**********************/
// AuthorizationCodeStore issues OAuth2 authorization codes and redeems them.
type AuthorizationCodeStore interface {
// GenerateAuthorizationCode stores the authentication of user for request r and
// returns a new code referencing it.
GenerateAuthorizationCode(ctx context.Context, r *AuthorizeRequest, user security.Authentication) (string, error)
// ConsumeAuthorizationCode returns the authentication referenced by authCode.
// The Redis implementation deletes the code when onetime is true.
ConsumeAuthorizationCode(ctx context.Context, authCode string, onetime bool) (oauth2.Authentication, error)
}
/**********************
Redis Impl
**********************/
// RedisAuthorizationCodeStore stores authorization codes in Redis as JSON-serialized
// oauth2.Authentication values with a TTL of authCodeValidity.
type RedisAuthorizationCodeStore struct {
redisClient redis.Client
}
// NewRedisAuthorizationCodeStore creates a store using a Redis client for database
// dbIndex obtained from cf. It panics if the client cannot be created.
func NewRedisAuthorizationCodeStore(ctx context.Context, cf redis.ClientFactory, dbIndex int) *RedisAuthorizationCodeStore {
	selectDB := func(opt *redis.ClientOption) {
		opt.DbIndex = dbIndex
	}
	client, e := cf.New(ctx, selectDB)
	if e != nil {
		panic(e)
	}
	return &RedisAuthorizationCodeStore{redisClient: client}
}
// GenerateAuthorizationCode builds an oauth2.Authentication from the request and the
// user, stores it under a new random alphabetic code of defaultAuthCodeLength
// characters, and returns the code. Storage failures become internal errors.
func (s *RedisAuthorizationCodeStore) GenerateAuthorizationCode(ctx context.Context, r *AuthorizeRequest, user security.Authentication) (string, error) {
	// code_challenge_method and code_challenge are kept in both parameters and
	// extensions, so they are persisted with the request without special handling
	oauthRequest, userAuth := r.OAuth2Request(), ConvertToOAuthUserAuthentication(user)
	authentication := oauth2.NewAuthentication(func(conf *oauth2.AuthOption) {
		conf.Request = oauthRequest
		conf.UserAuth = userAuth
	})

	code := utils.RandomStringWithCharset(defaultAuthCodeLength, utils.CharsetAlphabetic)
	err := s.save(ctx, code, authentication)
	if err != nil {
		return "", oauth2.NewInternalError(err)
	}
	return code, nil
}
// ConsumeAuthorizationCode loads the authentication stored for authCode.
//
// Any failure of the GET, whether the key is missing or expired or Redis is unreachable,
// and any undecodable payload are reported as an invalid grant.
//
// When onetime is true the key is deleted after loading. GET and DEL are separate
// commands, so two concurrent redemptions could both read the code. Only the caller
// whose DEL actually removes the key is accepted. A DEL that removes nothing means the
// code was already consumed (or just expired), and the code is rejected. A DEL that
// fails outright is only logged.
func (s *RedisAuthorizationCodeStore) ConsumeAuthorizationCode(ctx context.Context, authCode string, onetime bool) (oauth2.Authentication, error) {
	key := s.authCodeRedisKey(authCode)
	cmd := s.redisClient.Get(ctx, key)
	if cmd.Err() != nil {
		return nil, oauth2.NewInvalidGrantError(fmt.Sprintf("code [%s] is not valid", authCode))
	}

	toLoad := oauth2.NewAuthentication(func(opt *oauth2.AuthOption) {
		opt.Request = oauth2.NewOAuth2Request()
		opt.UserAuth = oauth2.NewUserAuthentication()
		opt.Details = map[string]interface{}{}
	})
	if e := json.Unmarshal([]byte(cmd.Val()), &toLoad); e != nil {
		return nil, oauth2.NewInvalidGrantError(fmt.Sprintf("code [%s] is not valid", authCode), e)
	}

	if onetime {
		del := s.redisClient.Del(ctx, key)
		switch {
		case del.Err() != nil:
			logger.WithContext(ctx).Warnf("authorization code was not removed: %v", del.Err())
		case del.Val() == 0:
			// someone else deleted the key between our GET and DEL: the code is already used
			return nil, oauth2.NewInvalidGrantError(fmt.Sprintf("code [%s] is not valid", authCode))
		}
	}
	return toLoad, nil
}
/**********************
Helpers
**********************/
// save serializes oauth as JSON and stores it under the code's Redis key with the
// authCodeValidity TTL.
func (s *RedisAuthorizationCodeStore) save(ctx context.Context, code string, oauth oauth2.Authentication) error {
	key := s.authCodeRedisKey(code)
	payload, err := json.Marshal(oauth)
	if err != nil {
		return err
	}
	return s.redisClient.Set(ctx, key, payload, authCodeValidity).Err()
}
// authCodeRedisKey returns the Redis key of a code: "<authCodePrefix>:<code>".
func (s *RedisAuthorizationCodeStore) authCodeRedisKey(code string) string {
	return authCodePrefix + ":" + code
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/pkg/web/template"
"github.com/gin-gonic/gin"
"net/http"
)
// Keys of the template model rendered by DefaultAuthorizeHandler.HandleApprovalPage.
const (
ApprovalModelKeyAuthRequest = "AuthRequest" // the *AuthorizeRequest awaiting approval
ApprovalModelKeyApprovalUrl = "ApprovalUrl" // the configured approval URL
)
// ResponseHandlerFunc writes the HTTP response of an authorize or approval request.
type ResponseHandlerFunc func(ctx *gin.Context)
// AuthorizeHandler produces the responses of the authorize and approval endpoints.
type AuthorizeHandler interface {
// HandleApproved makes various ResponseHandlerFunc of authorization based on
// - response_type
// - scope
// - other parameters
// If the implementation decides not to handle the AuthorizeRequest, it returns nil, nil.
// e.g. the OIDC impl doesn't handle non-OIDC requests, nor the "code" response type, because it's identical to the default oauth2 impl
HandleApproved(ctx context.Context, r *AuthorizeRequest, user security.Authentication) (ResponseHandlerFunc, error)
// HandleApprovalPage creates a ResponseHandlerFunc that renders the user approval page.
// As with HandleApproved, an extension may return nil, nil to decline the request.
HandleApprovalPage(ctx context.Context, r *AuthorizeRequest, user security.Authentication) (ResponseHandlerFunc, error)
}
/*************************
Default
*************************/
// AuthHandlerOptions is a functional option for NewAuthorizeHandler.
type AuthHandlerOptions func(opt *AuthHandlerOption)
// AuthHandlerOption holds the settings consumed by NewAuthorizeHandler.
type AuthHandlerOption struct {
Extensions []AuthorizeHandler // consulted before the default handling
ApprovalPageTmpl string // approval page template; defaults to "authorize.tmpl"
ApprovalUrl string // exposed to the approval page template model
AuthService AuthorizationService
AuthCodeStore AuthorizationCodeStore // required for the "code" response type
}
// DefaultAuthorizeHandler implements AuthorizeHandler.
// It implements the standard OAuth2 responses and keeps a list of extensions for additional protocols such as OpenID Connect.
// Extensions are consulted first, in order; the first one returning a non-nil result wins.
type DefaultAuthorizeHandler struct {
extensions []AuthorizeHandler
approvalPageTmpl string
approvalUrl string
authService AuthorizationService
authCodeStore AuthorizationCodeStore
}
// NewAuthorizeHandler creates a DefaultAuthorizeHandler from functional options.
// Defaults: no extensions and approval page template "authorize.tmpl". Nil option
// functions are skipped instead of causing a nil function call. Extensions are sorted
// with order.OrderedFirstCompare, which is the order in which they are consulted.
func NewAuthorizeHandler(opts ...AuthHandlerOptions) *DefaultAuthorizeHandler {
	opt := AuthHandlerOption{
		Extensions:       []AuthorizeHandler{},
		ApprovalPageTmpl: "authorize.tmpl",
	}
	for _, f := range opts {
		if f == nil {
			continue
		}
		f(&opt)
	}
	order.SortStable(opt.Extensions, order.OrderedFirstCompare)
	return &DefaultAuthorizeHandler{
		extensions:       opt.Extensions,
		approvalPageTmpl: opt.ApprovalPageTmpl,
		approvalUrl:      opt.ApprovalUrl,
		authService:      opt.AuthService,
		authCodeStore:    opt.AuthCodeStore,
	}
}
// Extend appends the given extensions, re-sorts all extensions with
// order.OrderedFirstCompare and returns h for chaining.
func (h *DefaultAuthorizeHandler) Extend(makers ...AuthorizeHandler) *DefaultAuthorizeHandler {
	for _, maker := range makers {
		h.extensions = append(h.extensions, maker)
	}
	order.SortStable(h.extensions, order.OrderedFirstCompare)
	return h
}
// HandleApproved produces the response for an approved request. It converts user to an
// oauth2.UserAuthentication, records the session ID in its details, and lets the first
// extension returning a non-nil result handle the request. Otherwise it falls back to
// the implicit ("token") or authorization code ("code") response. Any other
// response_type yields an invalid response type error.
func (h *DefaultAuthorizeHandler) HandleApproved(ctx context.Context, r *AuthorizeRequest, user security.Authentication) (ResponseHandlerFunc, error) {
	userAuth := ConvertToOAuthUserAuthentication(user)
	// common handling; it could also be implemented as an extension
	h.recordSessionId(ctx, userAuth)

	for _, ext := range h.extensions {
		f, e := ext.HandleApproved(ctx, r, userAuth)
		if f == nil && e == nil {
			continue // extension declined this request
		}
		return f, e
	}

	if r.ResponseTypes.Has("token") {
		return h.MakeImplicitResponse(ctx, r, userAuth)
	}
	if r.ResponseTypes.Has("code") {
		return h.MakeAuthCodeResponse(ctx, r, userAuth)
	}
	return nil, oauth2.NewInvalidResponseTypeError(fmt.Sprintf("response_type [%v] is not supported", r.ResponseTypes.Values()))
}
// HandleApprovalPage lets the first extension returning a non-nil result handle the
// request. Otherwise it returns a handler rendering the approval page template. The
// template model holds the request (ApprovalModelKeyAuthRequest) and the approval URL
// (ApprovalModelKeyApprovalUrl). Template encoding errors are ignored.
func (h *DefaultAuthorizeHandler) HandleApprovalPage(ctx context.Context, r *AuthorizeRequest, user security.Authentication) (ResponseHandlerFunc, error) {
	for _, ext := range h.extensions {
		f, e := ext.HandleApprovalPage(ctx, r, user)
		if f != nil || e != nil {
			return f, e
		}
	}
	//nolint:contextcheck // false positive
	return func(gc *gin.Context) {
		model := map[string]interface{}{
			ApprovalModelKeyAuthRequest: r,
			ApprovalModelKeyApprovalUrl: h.approvalUrl,
		}
		mv := template.ModelView{View: h.approvalPageTmpl, Model: model}
		_ = template.TemplateEncodeResponseFunc(gc, gc.Writer, &mv)
	}, nil
}
// MakeAuthCodeResponse issues an authorization code for r and user and returns a
// handler that redirects (302 Found) to the URL built by composeRedirectUrl with the
// code parameter. AuthHandlerOption.AuthCodeStore is optional in NewAuthorizeHandler, so
// a missing store is reported as an internal error instead of a nil dereference.
func (h *DefaultAuthorizeHandler) MakeAuthCodeResponse(ctx context.Context, r *AuthorizeRequest, user oauth2.UserAuthentication) (ResponseHandlerFunc, error) {
	if h.authCodeStore == nil {
		return nil, oauth2.NewInternalError("authorization code store is not configured")
	}
	code, e := h.authCodeStore.GenerateAuthorizationCode(ctx, r, user)
	if e != nil {
		return nil, e
	}
	// NOTE(review): this writes a redeemable credential into debug logs; consider removing
	logger.WithContext(ctx).Debug("authorization_code=" + code)

	values := map[string]string{
		oauth2.ParameterAuthCode: code,
	}
	redirect, e := composeRedirectUrl(ctx, r, values, false)
	if e != nil {
		return nil, e
	}
	return func(c *gin.Context) {
		c.Redirect(http.StatusFound, redirect)
	}, nil
}
// MakeImplicitResponse is meant to produce the implicit grant (response_type=token)
// response. The implicit grant is not implemented. HandleApproved routes any request
// whose response types include "token" here, which a client request can trigger, so
// this returns an invalid response type error instead of panicking.
func (h *DefaultAuthorizeHandler) MakeImplicitResponse(ctx context.Context, r *AuthorizeRequest, user oauth2.UserAuthentication) (ResponseHandlerFunc, error) {
	//TODO implement Implicit grant
	return nil, oauth2.NewInvalidResponseTypeError("response_type [token] is not supported")
}
/*
************************
Helpers
************************
*/
// recordSessionId stores the current session's ID in the user authentication details
// under security.DetailsKeySessionId. It does nothing when there is no session.
func (h *DefaultAuthorizeHandler) recordSessionId(ctx context.Context, user oauth2.UserAuthentication) {
	if s := session.Get(ctx); s != nil {
		user.DetailsMap()[security.DetailsKeySessionId] = s.GetID()
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"net/http"
"net/url"
"strings"
)
// AuthorizeRequest is a parsed OAuth2 authorization request together with a mutable
// context that request processors may use to pass values along.
type AuthorizeRequest struct {
Parameters map[string]string // when parsed: every string/fmt.Stringer parameter, standard ones included
ClientId string
ResponseTypes utils.StringSet
Scopes utils.StringSet
RedirectUri string
State string
Extensions map[string]interface{} // when parsed: the parameters not extracted into the fields above
Approved bool // set once the user approved the request
context utils.MutableContext
// resource IDs were removed from the OAuth2 spec.
// For backward compatibility, we use the client's registered values or hard code it to "nfv-api"
}
// Context returns the request's mutable context.
func (r *AuthorizeRequest) Context() utils.MutableContext {
	mc := r.context
	return mc
}
// WithContext replaces the request's context with utils.MakeMutableContext(ctx) and
// returns r itself; the receiver is modified in place rather than copied.
func (r *AuthorizeRequest) WithContext(ctx context.Context) *AuthorizeRequest {
	mutable := utils.MakeMutableContext(ctx)
	r.context = mutable
	return r
}
// OAuth2Request converts the request into an oauth2.OAuth2Request marked as approved.
// The grant type is taken from the oauth2.ParameterGrantType parameter when present.
func (r *AuthorizeRequest) OAuth2Request() oauth2.OAuth2Request {
	populate := func(details *oauth2.RequestDetails) {
		if gt, present := r.Parameters[oauth2.ParameterGrantType]; present {
			details.GrantType = gt
		}
		details.Parameters = r.Parameters
		details.ClientId = r.ClientId
		details.Scopes = r.Scopes
		details.Approved = true
		details.RedirectUri = r.RedirectUri
		details.ResponseTypes = r.ResponseTypes
		details.Extensions = r.Extensions
	}
	return oauth2.NewOAuth2Request(populate)
}
// String renders a compact summary of the request for logging. Sets and extensions
// use %v because Extensions may hold non-string values, which %s renders as
// "%!s(...)" noise.
func (r *AuthorizeRequest) String() string {
	return fmt.Sprintf("[client=%s, response_type=%v, redirect=%s, scope=%v, ext=%v]",
		r.ClientId, r.ResponseTypes, r.RedirectUri, r.Scopes, r.Extensions)
}
// NewAuthorizeRequest creates an empty request with initialized maps and sets, bound to
// a mutable background context, then applies opts in order.
func NewAuthorizeRequest(opts ...func(req *AuthorizeRequest)) *AuthorizeRequest {
	req := &AuthorizeRequest{
		Parameters:    map[string]string{},
		ResponseTypes: utils.NewStringSet(),
		Scopes:        utils.NewStringSet(),
		Extensions:    map[string]interface{}{},
		context:       utils.NewMutableContext(context.Background()),
	}
	for _, apply := range opts {
		apply(req)
	}
	return req
}
// ParseAuthorizeRequest parses req's form values (http.Request.ParseForm) into an
// AuthorizeRequest bound to req's context. Multi-valued fields are joined with spaces.
func ParseAuthorizeRequest(req *http.Request) (*AuthorizeRequest, error) {
	err := req.ParseForm()
	if err != nil {
		return nil, err
	}
	return ParseAuthorizeRequestWithKVs(req.Context(), flattenValuesToMap(req.Form))
}
// ParseAuthorizeRequestWithKVs builds an AuthorizeRequest from flat key-value pairs.
//
// Parameters receives every string or fmt.Stringer value. The client ID, response
// types, scopes, redirect URI and state parameters are extracted into their fields, and
// all remaining entries become Extensions. Response types and scopes are split on
// spaces. values itself is left untouched: extraction works on a copy, so callers'
// maps are neither modified nor aliased by the returned request.
func ParseAuthorizeRequestWithKVs(ctx context.Context, values map[string]interface{}) (*AuthorizeRequest, error) {
	params := toStringMap(values)
	remaining := make(map[string]interface{}, len(values))
	for k, v := range values {
		remaining[k] = v
	}
	return &AuthorizeRequest{
		Parameters:    params,
		ClientId:      extractStringParam(oauth2.ParameterClientId, remaining),
		ResponseTypes: extractStringSetParam(oauth2.ParameterResponseType, " ", remaining),
		Scopes:        extractStringSetParam(oauth2.ParameterScope, " ", remaining),
		RedirectUri:   extractStringParam(oauth2.ParameterRedirectUri, remaining),
		State:         extractStringParam(oauth2.ParameterState, remaining),
		Extensions:    remaining,
		context:       utils.MakeMutableContext(ctx),
	}, nil
}
/************************
Helpers
************************/
// flattenValuesToMap joins the values of each form field into one space-separated
// string. Fields without any value are omitted; the result is never nil.
func flattenValuesToMap(src url.Values) (dest map[string]interface{}) {
	dest = make(map[string]interface{}, len(src))
	for key, vals := range src {
		if len(vals) > 0 {
			dest[key] = strings.Join(vals, " ")
		}
	}
	return dest
}
// toStringMap keeps the entries of src whose value is a string or a fmt.Stringer
// (converted with String()). Entries of any other type, including nil, are dropped.
func toStringMap(src map[string]interface{}) (dest map[string]string) {
	dest = make(map[string]string, len(src))
	for k, v := range src {
		switch s := v.(type) {
		case string:
			dest[k] = s
		case fmt.Stringer:
			dest[k] = s.String()
		}
	}
	return dest
}
func extractStringParam(key string, params map[string]interface{}) string {
if v, ok := params[key]; ok {
delete(params, key)
return v.(string)
}
return ""
}
func extractStringSetParam(key, sep string, params map[string]interface{}) utils.StringSet {
if v, ok := params[key]; ok {
delete(params, key)
return utils.NewStringSet(strings.Split(v.(string), sep)...)
}
return utils.NewStringSet()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
)
/*****************************
Abstraction
*****************************/
// AuthorizeRequestProcessor validates and processes an incoming AuthorizeRequest.
// It is the entry point interface for other components to use; NewAuthorizeRequestProcessor
// builds one from a chain of ChainedAuthorizeRequestProcessor.
type AuthorizeRequestProcessor interface {
// Process returns the processed request, or an error if validation fails.
Process(ctx context.Context, request *AuthorizeRequest) (processed *AuthorizeRequest, err error)
}
// AuthorizeRequestProcessChain invokes the next processor in the processing chain.
type AuthorizeRequestProcessChain interface {
// Next runs the remaining processors; once none are left, request is returned unchanged.
Next(ctx context.Context, request *AuthorizeRequest) (processed *AuthorizeRequest, err error)
}
// ChainedAuthorizeRequestProcessor validates and processes an incoming request and is
// responsible for invoking the rest of the chain via chain.Next; the remaining
// processors are skipped if it doesn't.
type ChainedAuthorizeRequestProcessor interface {
Process(ctx context.Context, request *AuthorizeRequest, chain AuthorizeRequestProcessChain) (validated *AuthorizeRequest, err error)
}
/*****************************
Common Implementations
*****************************/
// authorizeRequestProcessor implements AuthorizeRequestProcessor by running its
// delegates as a chain, in slice order.
type authorizeRequestProcessor struct {
delegates []ChainedAuthorizeRequestProcessor
}
// NewAuthorizeRequestProcessor chains delegates into one AuthorizeRequestProcessor.
// Delegates run in the given order; with none, requests pass through unchanged.
func NewAuthorizeRequestProcessor(delegates ...ChainedAuthorizeRequestProcessor) AuthorizeRequestProcessor {
	p := authorizeRequestProcessor{delegates: delegates}
	return &p
}
// Process starts a fresh chain at the first delegate for every call.
func (p *authorizeRequestProcessor) Process(ctx context.Context, request *AuthorizeRequest) (processed *AuthorizeRequest, err error) {
	return arProcessChain{index: 0, delegates: p.delegates}.Next(ctx, request)
}
// arProcessChain implements AuthorizeRequestProcessChain.
// index is the position of the next delegate to invoke; the chain is passed by value,
// so each step works on its own copy of index.
type arProcessChain struct {
index int
delegates []ChainedAuthorizeRequestProcessor
}
// Next invokes the delegate at the current index and hands it a chain advanced by one.
// The receiver is a value, so the caller's chain is never modified. Once every delegate
// has run, request is returned as-is.
func (c arProcessChain) Next(ctx context.Context, request *AuthorizeRequest) (processed *AuthorizeRequest, err error) {
	if c.index < len(c.delegates) {
		current := c.delegates[c.index]
		rest := arProcessChain{index: c.index + 1, delegates: c.delegates}
		return current.Process(ctx, request, rest)
	}
	return request, nil
}
//func (c *authorizeRequestProcessor) Add(processors ...ChainedAuthorizeRequestProcessor) {
// c.delegates = append(c.delegates, flattenProcessors(processors)...)
// // resort the extensions
// order.SortStable(c.delegates, order.OrderedFirstCompare)
//}
//
//func (c *authorizeRequestProcessor) Remove(processor ChainedAuthorizeRequestProcessor) {
// for i, item := range c.delegates {
// if item != processor {
// continue
// }
//
// // remove but keep order
// if i+1 <= len(c.delegates) {
// copy(c.delegates[i:], c.delegates[i+1:])
// }
// c.delegates = c.delegates[:len(c.delegates)-1]
// return
// }
//}
//
//// flattenProcessors recursively flatten any nested NestedAuthorizeRequestProcessor
//func flattenProcessors(processors []ChainedAuthorizeRequestProcessor) (ret []ChainedAuthorizeRequestProcessor) {
// ret = make([]ChainedAuthorizeRequestProcessor, 0, len(processors))
// for _, e := range processors {
// switch e.(type) {
// case *authorizeRequestProcessor:
// flattened := flattenProcessors(e.(*authorizeRequestProcessor).delegates)
// ret = append(ret, flattened...)
// default:
// ret = append(ret, e)
// }
// }
// return
//}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// redirectGrantTypes lists the grant types that are allowed to use a redirect_uri.
var redirectGrantTypes = []string{oauth2.GrantTypeAuthCode, oauth2.GrantTypeImplicit}

// supportedResponseTypes lists the response_type values accepted by StandardAuthorizeRequestProcessor.
var supportedResponseTypes = utils.NewStringSet("token", "code")
// StandardAuthorizeRequestProcessor implements ChainedAuthorizeRequestProcessor.
// It validates an authorize request against the standard OAuth2 rules: response types,
// client ID, redirect URI and scopes, in that order.
// NOTE(review): the original comment also claimed order.Ordered, but no Order() method is
// visible here — confirm before relying on ordering.
type StandardAuthorizeRequestProcessor struct {
// clientStore is used to load the client referenced by the request's client ID.
clientStore oauth2.OAuth2ClientStore
// accountStore is stored but not read by any method visible here.
accountStore security.AccountStore
}
// StdARPOptions is a functional option that configures a StdARPOption.
type StdARPOptions func(*StdARPOption)
// StdARPOption holds the dependencies consumed by NewStandardAuthorizeRequestProcessor.
type StdARPOption struct {
ClientStore oauth2.OAuth2ClientStore
AccountStore security.AccountStore
}
// NewStandardAuthorizeRequestProcessor builds a StandardAuthorizeRequestProcessor.
// Options are applied in order onto a zero StdARPOption, so later options override earlier ones.
func NewStandardAuthorizeRequestProcessor(opts ...StdARPOptions) *StandardAuthorizeRequestProcessor {
	var cfg StdARPOption
	for i := range opts {
		opts[i](&cfg)
	}
	p := StandardAuthorizeRequestProcessor{}
	p.clientStore = cfg.ClientStore
	p.accountStore = cfg.AccountStore
	return &p
}
// Process validates the request in this order: response types, client ID, redirect URI, scopes.
// Along the way it records, in the request context, the authenticated client, the resolved
// redirect URI and (when non-empty) the state. On success it continues with the rest of the chain.
func (p *StandardAuthorizeRequestProcessor) Process(ctx context.Context, request *AuthorizeRequest, chain AuthorizeRequestProcessChain) (validated *AuthorizeRequest, err error) {
	if err = p.validateResponseTypes(ctx, request); err != nil {
		return nil, err
	}

	client, err := p.validateClientId(ctx, request)
	if err != nil {
		return nil, err
	}
	request.Context().Set(oauth2.CtxKeyAuthenticatedClient, client)

	if err = p.validateRedirectUri(ctx, request, client); err != nil {
		return nil, err
	}
	// from here on the redirect URI has been validated and may be used
	request.Context().Set(oauth2.CtxKeyResolvedAuthorizeRedirect, request.RedirectUri)
	if state := request.State; state != "" {
		request.Context().Set(oauth2.CtxKeyResolvedAuthorizeState, state)
	}

	if err = p.validateScope(ctx, request, client); err != nil {
		return nil, err
	}
	return chain.Next(ctx, request)
}
// validateResponseTypes delegates to ValidateResponseTypes using the package-level
// supportedResponseTypes ("token", "code").
func (p *StandardAuthorizeRequestProcessor) validateResponseTypes(ctx context.Context, request *AuthorizeRequest) error {
	supported := supportedResponseTypes
	return ValidateResponseTypes(ctx, request, supported)
}

// validateClientId loads and validates the client named by request.ClientId via the configured client store.
func (p *StandardAuthorizeRequestProcessor) validateClientId(ctx context.Context, request *AuthorizeRequest) (oauth2.OAuth2Client, error) {
	store := p.clientStore
	clientID := request.ClientId
	return LoadAndValidateClientId(ctx, clientID, store)
}
// validateRedirectUri checks that the client may use a redirect URI at all, then resolves the
// effective redirect URI and stores it back on request.RedirectUri.
//
// A client may use a redirect URI only if it has at least one grant type from redirectGrantTypes
// (authorization_code or implicit). The effective URI is whatever ResolveRedirectUri returns for
// the requested URI and this client; its error is returned unchanged.
func (p *StandardAuthorizeRequestProcessor) validateRedirectUri(ctx context.Context, request *AuthorizeRequest, client oauth2.OAuth2Client) error {
	grants := client.GrantTypes()
	// len of a nil set is 0, so no separate nil check is needed
	if len(grants) == 0 {
		return oauth2.NewInvalidAuthorizeRequestError("client must have at least one authorized grant type")
	}

	allowed := false
	for _, grant := range redirectGrantTypes {
		if grants.Has(grant) {
			allowed = true
			break
		}
	}
	if !allowed {
		return oauth2.NewInvalidAuthorizeRequestError(
			"redirect_uri can only be used by implicit or authorization_code grant types")
	}

	// The resolved redirect URI is either the redirect_uri from the parameters or the one from
	// the client's registration. Either way it is stored on the request.
	redirect, e := ResolveRedirectUri(ctx, request.RedirectUri, client)
	if e != nil {
		return e
	}
	request.RedirectUri = redirect
	return nil
}
// validateScope defaults request.Scopes to a copy of the client's scopes when none were
// requested. Otherwise every requested scope must pass ValidateAllScopes for this client.
func (p *StandardAuthorizeRequestProcessor) validateScope(ctx context.Context, request *AuthorizeRequest, client oauth2.OAuth2Client) error {
	// len of a nil set is 0, so no separate nil check is needed
	if len(request.Scopes) == 0 {
		// Copy() so request.Scopes does not alias the client's own set
		request.Scopes = client.Scopes().Copy()
		return nil
	}
	return ValidateAllScopes(ctx, client, request.Scopes)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package claims
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"reflect"
"time"
)
// ClientId produces the client ID from the source OAuth2 request.
// Returns errorMissingRequest without a request and errorMissingDetails when the ID is empty.
func ClientId(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	req := opt.Source.OAuth2Request()
	if req == nil {
		return nil, errorMissingRequest
	}
	return nonZeroOrError(req.ClientId(), errorMissingDetails)
}

// Audience produces a single-element string set containing the requesting client's ID.
func Audience(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	req := opt.Source.OAuth2Request()
	switch {
	case req == nil:
		return nil, errorMissingRequest
	case req.ClientId() == "":
		return nil, errorMissingDetails
	default:
		return utils.NewStringSet(req.ClientId()), nil
	}
}

// JwtId copies the JWT ID claim (oauth2.ClaimJwtId) from the access token.
func JwtId(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	v, err = extractAccessTokenClaim(opt, oauth2.ClaimJwtId)
	return v, err
}
// ExpiresAt produces the expiry time. When the source's details implement
// security.ContextDetails they take precedence over the access token's expiry.
// Returns errorMissingDetails when neither source is available or the resolved time is zero.
func ExpiresAt(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	// A typed time.Time (instead of an untyped nil interface) guarantees nonZeroOrError never
	// receives nil; reflect.Value.IsZero panics on the invalid Value produced by reflect.ValueOf(nil).
	var expiry time.Time
	if token := opt.Source.AccessToken(); token != nil {
		expiry = token.ExpiryTime()
	}
	if details, ok := opt.Source.Details().(security.ContextDetails); ok {
		expiry = details.ExpiryTime()
	}
	return nonZeroOrError(expiry, errorMissingDetails)
}

// IssuedAt produces the issue time. When the source's details implement
// security.ContextDetails they take precedence over the access token's issue time.
// Returns errorMissingDetails when neither source is available or the resolved time is zero.
func IssuedAt(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	// typed zero value instead of nil: see ExpiresAt above
	var issued time.Time
	if token := opt.Source.AccessToken(); token != nil {
		issued = token.IssueTime()
	}
	if details, ok := opt.Source.Details().(security.ContextDetails); ok {
		issued = details.IssueTime()
	}
	return nonZeroOrError(issued, errorMissingDetails)
}
// Issuer prefers the configured issuer's non-empty identifier. Otherwise it falls back to the
// access token's issuer claim (oauth2.ClaimIssuer).
func Issuer(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if iss := opt.Issuer; iss != nil {
		if id := iss.Identifier(); id != "" {
			return id, nil
		}
	}
	// fall back to extract from access token
	return extractAccessTokenClaim(opt, oauth2.ClaimIssuer)
}

// NotBefore copies the not-before claim (oauth2.ClaimNotBefore) from the access token.
func NotBefore(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	claimName := oauth2.ClaimNotBefore
	return extractAccessTokenClaim(opt, claimName)
}

// Subject is the username of the source's user authentication (see Username).
func Subject(ctx context.Context, opt *FactoryOption) (v interface{}, err error) {
	v, err = Username(ctx, opt)
	return v, err
}
// Scopes produces the scopes of the source OAuth2 request. An empty set yields errorMissingDetails.
func Scopes(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	req := opt.Source.OAuth2Request()
	if req == nil {
		return nil, errorMissingRequest
	}
	scopes := req.Scopes()
	return nonZeroOrError(scopes, errorMissingDetails)
}

// Username resolves the username of the source's user authentication.
// errorMissingUser: no user authentication, no principal, or the username cannot be resolved.
// errorMissingDetails: the resolved username is empty.
func Username(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	userAuth := opt.Source.UserAuthentication()
	if userAuth == nil || userAuth.Principal() == nil {
		return nil, errorMissingUser
	}
	name, e := security.GetUsername(userAuth)
	if e != nil {
		return nil, errorMissingUser
	}
	return nonZeroOrError(name, errorMissingDetails)
}
// nonZeroOrError returns (v, nil) when v holds a non-zero value, otherwise (nil, candidateError).
//
// "Zero" means a nil interface, an empty string, a zero time.Time, an empty utils.StringSet,
// or any other value for which reflect.Value.IsZero reports true (nil pointers, 0, false, ...).
// The nil interface must be handled explicitly: reflect.ValueOf(nil) is an invalid Value and
// calling IsZero on it panics.
func nonZeroOrError(v interface{}, candidateError error) (interface{}, error) {
	var isZero bool
	switch t := v.(type) {
	case nil:
		isZero = true
	case string:
		isZero = t == ""
	case time.Time:
		isZero = t.IsZero()
	case utils.StringSet:
		isZero = len(t) == 0
	default:
		isZero = reflect.ValueOf(v).IsZero()
	}
	if isZero {
		return nil, candidateError
	}
	return v, nil
}
// extractAccessToken returns the explicitly supplied FactoryOption.AccessToken, falling back
// to the source authentication's access token when none was supplied.
func extractAccessToken(opt *FactoryOption) oauth2.AccessToken {
	if opt.AccessToken != nil {
		return opt.AccessToken
	}
	return opt.Source.AccessToken()
}

// extractAccessTokenClaim reads one claim from the token chosen by extractAccessToken.
// errorMissingToken: the token is absent, is not an oauth2.ClaimsContainer, or has no claims.
// errorMissingClaims: the claim is not present.
func extractAccessTokenClaim(opt *FactoryOption, claim string) (v interface{}, err error) {
	container, ok := extractAccessToken(opt).(oauth2.ClaimsContainer)
	if !ok {
		return nil, errorMissingToken
	}
	claims := container.Claims()
	switch {
	case claims == nil:
		return nil, errorMissingToken
	case !claims.Has(claim):
		return nil, errorMissingClaims
	default:
		return claims.Get(claim), nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package claims
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// UserId produces the user ID from security.UserDetails on the source. Empty IDs are rejected.
func UserId(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.UserDetails); ok {
		return nonZeroOrError(details.UserId(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// AccountType produces the string form of the user's account type.
func AccountType(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.UserDetails); ok {
		return nonZeroOrError(details.AccountType().String(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// Currency produces the user's currency code.
func Currency(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.UserDetails); ok {
		return nonZeroOrError(details.CurrencyCode(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}
// DefaultTenantId produces the default designated tenant ID. The account is obtained via
// tryReloadAccount and must implement security.AccountTenancy.
func DefaultTenantId(ctx context.Context, opt *FactoryOption) (v interface{}, err error) {
	if tenancy, ok := tryReloadAccount(ctx, opt).(security.AccountTenancy); ok {
		return nonZeroOrError(tenancy.DefaultDesignatedTenantId(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// TenantId produces the tenant ID from security.TenantDetails on the source.
func TenantId(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.TenantDetails); ok {
		return nonZeroOrError(details.TenantId(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// TenantExternalId produces the tenant's external ID from security.TenantDetails on the source.
func TenantExternalId(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.TenantDetails); ok {
		return nonZeroOrError(details.TenantExternalId(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// TenantSuspended produces a *bool with the tenant's suspended flag. false is a valid value here.
func TenantSuspended(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.TenantDetails); ok {
		return utils.BoolPtr(details.TenantSuspended()), nil
	}
	return nil, errorMissingDetails
}
// ProviderId produces the provider ID from security.ProviderDetails on the source.
func ProviderId(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.ProviderDetails); ok {
		return nonZeroOrError(details.ProviderId(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// ProviderName produces the provider name from security.ProviderDetails on the source.
func ProviderName(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.ProviderDetails); ok {
		return nonZeroOrError(details.ProviderName(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// ProviderDisplayName produces the provider display name from security.ProviderDetails on the source.
func ProviderDisplayName(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.ProviderDetails); ok {
		return nonZeroOrError(details.ProviderDisplayName(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// ProviderDescription produces the provider description from security.ProviderDetails on the source.
func ProviderDescription(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.ProviderDetails); ok {
		return nonZeroOrError(details.ProviderDescription(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// ProviderEmail produces the provider email from security.ProviderDetails on the source.
func ProviderEmail(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.ProviderDetails); ok {
		return nonZeroOrError(details.ProviderEmail(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// ProviderNotificationType produces the provider notification type from security.ProviderDetails on the source.
func ProviderNotificationType(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.ProviderDetails); ok {
		return nonZeroOrError(details.ProviderNotificationType(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}
// Roles produces the roles from security.AuthenticationDetails on the source. Empty values are rejected.
func Roles(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.AuthenticationDetails); ok {
		return nonZeroOrError(details.Roles(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// Permissions produces the permissions from security.AuthenticationDetails on the source. Empty values are rejected.
func Permissions(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.AuthenticationDetails); ok {
		return nonZeroOrError(details.Permissions(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}
// OriginalUsername produces the original username reported by security.ProxiedUserDetails.
// errorMissingDetails is returned when the details are not ProxiedUserDetails, when the
// authentication is not proxied, or when the original username is empty.
func OriginalUsername(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	details, ok := opt.Source.Details().(security.ProxiedUserDetails)
	if !ok || !details.Proxied() {
		return nil, errorMissingDetails
	}
	return nonZeroOrError(details.OriginalUsername(), errorMissingDetails)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package claims
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// errTmplCreateClaimFailed is the fmt template for claim failures: claim name, then cause.
const errTmplCreateClaimFailed = `unable to create claim [%s]: %v`

// Sentinel errors reported by claim factories and specs.
var errorInvalidSpec = errors.New("invalid claim spec")
var errorMissingToken = errors.New("source authentication is missing valid token")
var errorMissingRequest = errors.New("source authentication is missing OAuth2 request")
var errorMissingUser = errors.New("source authentication is missing user")
var errorMissingDetails = errors.New("source authentication is missing required details")
var errorMissingClaims = errors.New("source authentication is missing required token claims")
var errorMissingRequestParams = errors.New("source authentication's OAuth2 request is missing parameters")
// ClaimFactoryFunc computes a claim value from the factory options; a non-nil error means the
// value could not be produced.
type ClaimFactoryFunc func(ctx context.Context, opt *FactoryOption) (v interface{}, err error)
// ClaimRequirementFunc reports whether a claim is required, i.e. whether a ClaimFactoryFunc
// error should fail population instead of the claim being skipped.
type ClaimRequirementFunc func(ctx context.Context, opt *FactoryOption) bool
// FactoryOptions is a functional option that Populate applies to a FactoryOption.
type FactoryOptions func(opt *FactoryOption)
// FactoryOption carries everything claim factories may read while populating claims.
type FactoryOption struct {
// Specs are claim specs that Populate always evaluates, in slice order.
Specs []map[string]ClaimSpec
// Source is the authentication claims are derived from; most factories dereference it without a nil check.
Source oauth2.Authentication
// Issuer, when set, supplies the issuer identifier and the level of assurance.
Issuer security.Issuer
// AccountStore, when set, lets tryReloadAccount load the account by user ID.
AccountStore security.AccountStore
// AccessToken, when set, takes precedence over Source's access token in extractAccessToken.
AccessToken oauth2.AccessToken
// RequestedClaims, when non-nil, enables evaluation of ClaimsFormula filtered by these requests.
RequestedClaims RequestedClaims
// ClaimsFormula are claim specs evaluated only for claims that were requested as essential.
ClaimsFormula []map[string]ClaimSpec
// ExtraSource holds additional caller-provided values; no code visible here reads it.
ExtraSource map[string]interface{}
}
// WithSpecs appends claim specs that Populate always evaluates.
func WithSpecs(specs ...map[string]ClaimSpec) FactoryOptions {
	return func(o *FactoryOption) {
		o.Specs = append(o.Specs, specs...)
	}
}

// WithRequestedClaims sets the requested claims and replaces (does not append to) the claims formula.
func WithRequestedClaims(requested RequestedClaims, formula ...map[string]ClaimSpec) FactoryOptions {
	return func(o *FactoryOption) {
		o.RequestedClaims, o.ClaimsFormula = requested, formula
	}
}

// WithSource sets the authentication that claims are derived from.
func WithSource(oauth oauth2.Authentication) FactoryOptions {
	return func(o *FactoryOption) {
		o.Source = oauth
	}
}

// WithIssuer sets the issuer used for issuer and level-of-assurance claims.
func WithIssuer(issuer security.Issuer) FactoryOptions {
	return func(o *FactoryOption) {
		o.Issuer = issuer
	}
}

// WithAccountStore sets the store used to reload the account when needed.
func WithAccountStore(accountStore security.AccountStore) FactoryOptions {
	return func(o *FactoryOption) {
		o.AccountStore = accountStore
	}
}

// WithAccessToken sets an explicit access token that overrides the source's token.
func WithAccessToken(token oauth2.AccessToken) FactoryOptions {
	return func(o *FactoryOption) {
		o.AccessToken = token
	}
}

// WithExtraSource sets additional caller-provided values.
func WithExtraSource(extra map[string]interface{}) FactoryOptions {
	return func(o *FactoryOption) {
		o.ExtraSource = extra
	}
}
// Populate fills claims according to opts.
//
// First every map in FactoryOption.Specs is evaluated. Then, only when RequestedClaims is set,
// every map in ClaimsFormula is evaluated, restricted to claims that were requested and marked
// essential. The first error aborts population; claims already set are left in place.
func Populate(ctx context.Context, claims oauth2.Claims, opts ...FactoryOptions) error {
	var opt FactoryOption
	for _, apply := range opts {
		apply(&opt)
	}

	for _, specs := range opt.Specs {
		if err := populateWithSpecs(ctx, claims, specs, &opt, nil); err != nil {
			return err
		}
	}

	if opt.RequestedClaims == nil {
		return nil
	}
	// exclude claims that were not requested, or were requested but not as essential
	excludeNonEssential := func(name string, _ ClaimSpec) bool {
		req, found := opt.RequestedClaims.Get(name)
		return !found || !req.Essential()
	}
	for _, specs := range opt.ClaimsFormula {
		if err := populateWithSpecs(ctx, claims, specs, &opt, excludeNonEssential); err != nil {
			return err
		}
	}
	return nil
}
// claimSpecFilter reports whether a claim should be excluded from population.
type claimSpecFilter func(name string, spec ClaimSpec) (exclude bool)

// populateWithSpecs evaluates each spec and sets the resulting value on claims.
// Empty claim names, and names excluded by a non-nil filter, are skipped. A failing spec aborts
// with a wrapped error only when it is required; otherwise the claim is omitted. Map iteration
// order is unspecified, so which failure is reported first is not deterministic.
func populateWithSpecs(ctx context.Context, claims oauth2.Claims, specs map[string]ClaimSpec, opt *FactoryOption, filter claimSpecFilter) error {
	for name, spec := range specs {
		if name == "" {
			continue
		}
		if filter != nil && filter(name, spec) {
			continue
		}
		value, calcErr := spec.Calculate(ctx, opt)
		if calcErr != nil {
			if spec.Required(ctx, opt) {
				return fmt.Errorf(errTmplCreateClaimFailed, name, calcErr)
			}
			continue
		}
		// check type and assign
		if setErr := safeSet(claims, name, value); setErr != nil {
			return setErr
		}
	}
	return nil
}
// safeSet assigns value to claim. A panic raised by claims.Set (presumably on a type mismatch)
// is recovered and returned as an error formatted with errTmplCreateClaimFailed.
func safeSet(claims oauth2.Claims, claim string, value interface{}) (err error) {
	defer func() {
		// %v renders an error and any other panic value the same way, so one branch suffices
		if r := recover(); r != nil {
			err = fmt.Errorf(errTmplCreateClaimFailed, claim, r)
		}
	}()
	claims.Set(claim, value)
	return nil
}
/*
************************
helpers
************************
*/
// tryReloadAccount returns the authenticated account on a best-effort basis.
//
// An account already present in ctx under oauth2.CtxKeyAuthenticatedAccount is returned as-is.
// Otherwise, when an AccountStore is configured and the source details are security.UserDetails,
// the account is loaded by user ID and stored in the mutable context returned by
// utils.FindMutableContext (if any). Every failure yields nil; store errors are not reported.
func tryReloadAccount(ctx context.Context, opt *FactoryOption) security.Account {
	if cached, ok := ctx.Value(oauth2.CtxKeyAuthenticatedAccount).(security.Account); ok {
		return cached
	}
	store := opt.AccountStore
	if store == nil {
		return nil
	}
	details, isUser := opt.Source.Details().(security.UserDetails)
	if !isUser {
		return nil
	}
	account, err := store.LoadAccountById(ctx, details.UserId())
	if err != nil {
		return nil
	}
	// cache it in context if possible
	if mc := utils.FindMutableContext(ctx); mc != nil {
		mc.Set(oauth2.CtxKeyAuthenticatedAccount, account)
	}
	return account
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package claims
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// LegacyAudience returns the authenticated client's resource IDs for the legacy audience value.
// The Java/Spring implementation put resource IDs in "aud", which has since been deprecated.
// Without a client in ctx, or when the client has no resource IDs, it falls back to
// oauth2.LegacyResourceId.
func LegacyAudience(ctx context.Context, opt *FactoryOption) utils.StringSet {
	if client, ok := ctx.Value(oauth2.CtxKeyAuthenticatedClient).(oauth2.OAuth2Client); ok {
		// len of a nil set is 0, so this also covers a nil set
		if ids := client.ResourceIDs(); len(ids) != 0 {
			return ids
		}
	}
	return utils.NewStringSet(oauth2.LegacyResourceId)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package claims
import (
"context"
"crypto"
"encoding/base64"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
"github.com/cisco-open/go-lanai/pkg/utils"
"strings"
)
// AddressClaim is defined at https://openid.net/specs/openid-connect-core-1_0.html#AddressClaim
// Field names follow Go conventions; the JSON tags carry the OIDC member names.
type AddressClaim struct {
Formatted string `json:"formatted,omitempty"`
StreetAddr string `json:"street_address,omitempty"`
// City is serialized as the OIDC "locality" member.
City string `json:"locality,omitempty"`
Region string `json:"region,omitempty"`
PostalCode string `json:"postal_code,omitempty"`
Country string `json:"country,omitempty"`
}
// AuthenticationTime produces the authentication time from security.AuthenticationDetails.
// A zero time is rejected with errorMissingDetails.
func AuthenticationTime(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.AuthenticationDetails); ok {
		return nonZeroOrError(details.AuthenticationTime(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}
// Nonce echoes the nonce request parameter (oauth2.ParameterNonce) of the source OAuth2 request.
// errorMissingRequest: no request or no parameters. errorMissingRequestParams: nonce absent or empty.
func Nonce(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	req := opt.Source.OAuth2Request()
	if req == nil || req.Parameters() == nil {
		return nil, errorMissingRequest
	}
	// a missing key yields the zero value, which nonZeroOrError rejects; no comma-ok needed
	nonce := req.Parameters()[oauth2.ParameterNonce]
	return nonZeroOrError(nonce, errorMissingRequestParams)
}
// AuthContextClassRef produces the authentication context class reference: the issuer's level of
// assurance 3 when MFA was applied, otherwise level 2. It requires a configured issuer and a
// recorded authentication method, and returns errorMissingDetails otherwise.
func AuthContextClassRef(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if opt.Issuer == nil || extractAuthMethod(opt) == "" {
		return nil, errorMissingDetails
	}
	if extractMFAApplied(opt) {
		return opt.Issuer.LevelOfAssurance(3), nil
	}
	return opt.Issuer.LevelOfAssurance(2), nil
}
// AuthMethodRef produces the authentication method references: the mapped primary method (when
// recognized by authMethodString), followed by "otp" when MFA was applied.
// An empty result yields errorMissingDetails.
func AuthMethodRef(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	primary := authMethodString(extractAuthMethod(opt))
	mfa := extractMFAApplied(opt)
	if primary == "" && !mfa {
		return nil, errorMissingDetails
	}
	methods := make([]string, 0, 2)
	if primary != "" {
		methods = append(methods, primary)
	}
	if mfa {
		methods = append(methods, "otp")
	}
	return methods, nil
}

// AccessTokenHash produces the access token hash (see calculateAccessTokenHash) of the token
// chosen by extractAccessToken. A missing or empty token yields errorMissingToken.
func AccessTokenHash(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	token := extractAccessToken(opt)
	if token == nil {
		return nil, errorMissingToken
	}
	raw := token.Value()
	if raw == "" {
		return nil, errorMissingToken
	}
	return calculateAccessTokenHash(raw)
}
// FullName joins first and last name with a space and trims surrounding whitespace.
// An empty result yields errorMissingDetails.
func FullName(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	details, ok := opt.Source.Details().(security.UserDetails)
	if !ok {
		return nil, errorMissingDetails
	}
	full := strings.TrimSpace(details.FirstName() + " " + details.LastName())
	return nonZeroOrError(full, errorMissingDetails)
}

// FirstName produces the user's first name from security.UserDetails.
func FirstName(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.UserDetails); ok {
		return nonZeroOrError(details.FirstName(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// LastName produces the user's last name from security.UserDetails.
func LastName(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.UserDetails); ok {
		return nonZeroOrError(details.LastName(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// Email produces the user's email from security.UserDetails.
func Email(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.UserDetails); ok {
		return nonZeroOrError(details.Email(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}

// EmailVerified produces a *bool that is true whenever the user has a non-blank email.
// NOTE(review): no actual verification status is consulted here.
func EmailVerified(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	details, ok := opt.Source.Details().(security.UserDetails)
	if !ok {
		return nil, errorMissingDetails
	}
	hasEmail := strings.TrimSpace(details.Email()) != ""
	return utils.ToPtr(hasEmail), nil
}
// ZoneInfo is not supported yet and always yields errorMissingDetails.
// It could perhaps be derived from the locale in the future.
func ZoneInfo(_ context.Context, _ *FactoryOption) (v interface{}, err error) {
	err = errorMissingDetails
	return nil, err
}

// Locale produces the user's locale code from security.UserDetails.
func Locale(_ context.Context, opt *FactoryOption) (v interface{}, err error) {
	if details, ok := opt.Source.Details().(security.UserDetails); ok {
		return nonZeroOrError(details.LocaleCode(), errorMissingDetails)
	}
	return nil, errorMissingDetails
}
// Address produces an *AddressClaim for the account returned by tryReloadAccount, which must
// implement security.AccountMetadata.
// NOTE(review): Formatted is filled with the account's locale code rather than an address, and
// all other fields are left empty — this looks like a placeholder; confirm before relying on it.
func Address(ctx context.Context, opt *FactoryOption) (v interface{}, err error) {
	meta, ok := tryReloadAccount(ctx, opt).(security.AccountMetadata)
	if !ok || meta == nil {
		return nil, errorMissingDetails
	}
	return &AddressClaim{Formatted: meta.LocaleCode()}, nil
}
/********************
Helpers
********************/
var (
	// jwtHashAlgorithms maps a JWS "alg" header value to the SHA-2 function whose digest is used
	// for the access token hash. Per RFC 7518 §3, the HS/RS/ES/PS families all use the SHA-2
	// variant matching the alg's bit length. Algorithms absent here are rejected by
	// calculateAccessTokenHash.
	jwtHashAlgorithms = map[string]crypto.Hash{
		"HS256": crypto.SHA256,
		"RS256": crypto.SHA256,
		"ES256": crypto.SHA256,
		"PS256": crypto.SHA256,
		"HS384": crypto.SHA384,
		"RS384": crypto.SHA384,
		"ES384": crypto.SHA384,
		"PS384": crypto.SHA384,
		"HS512": crypto.SHA512,
		"RS512": crypto.SHA512,
		"ES512": crypto.SHA512,
		"PS512": crypto.SHA512,
	}
)
// calculateAccessTokenHash computes the access token hash: hash the token's bytes with the
// function mapped from the token's own "alg" header, keep the left-most half of the digest, and
// encode it as unpadded base64url.
// NOTE(review): OpenID Connect Core §3.1.3.6 specifies the hash of the *ID token's* alg; this
// uses the access token's header, which only matches when both tokens use the same hash size.
func calculateAccessTokenHash(token string) (string, error) {
	headers, err := jwt.ParseJwtHeaders(token)
	if err != nil {
		return "", err
	}
	algName, _ := headers["alg"].(string)
	h, supported := jwtHashAlgorithms[algName]
	if !supported || !h.Available() {
		return "", fmt.Errorf(`hash is unsupported for access token with alg="%s"`, algName)
	}

	hasher := h.New()
	if _, werr := hasher.Write([]byte(token)); werr != nil {
		return "", werr
	}
	digest := hasher.Sum(nil)
	return base64.RawURLEncoding.EncodeToString(digest[:len(digest)/2]), nil
}
// extractAuthMethod returns the string stored under security.DetailsKeyAuthMethod in the user
// authentication's details map, or "" when there is no user authentication, the details are
// not a map[string]interface{}, or the value is not a string.
func extractAuthMethod(opt *FactoryOption) (ret string) {
	userAuth := opt.Source.UserAuthentication()
	if userAuth == nil {
		return ""
	}
	if details, ok := userAuth.Details().(map[string]interface{}); ok {
		ret, _ = details[security.DetailsKeyAuthMethod].(string)
	}
	return ret
}

// extractMFAApplied returns the bool stored under security.DetailsKeyMFAApplied in the user
// authentication's details map, defaulting to false whenever it cannot be read.
func extractMFAApplied(opt *FactoryOption) (ret bool) {
	userAuth := opt.Source.UserAuthentication()
	if userAuth == nil {
		return false
	}
	if details, ok := userAuth.Details().(map[string]interface{}); ok {
		ret, _ = details[security.DetailsKeyMFAApplied].(bool)
	}
	return ret
}
// authMethodString maps an internal authentication method to its reference value
// ("password", "saml" or "openid"). Unrecognized methods map to "".
func authMethodString(authMethod string) (ret string) {
	switch authMethod {
	case security.AuthMethodPassword:
		ret = "password"
	case security.AuthMethodExternalSaml:
		ret = "saml"
	case security.AuthMethodExternalOpenID:
		ret = "openid"
	}
	return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package claims
import "context"
// RequestedClaims looks up claims explicitly requested by the client.
// NOTE(review): presumably sourced from an OIDC "claims" request parameter — confirm with callers.
type RequestedClaims interface {
// Get returns the request for the named claim and whether the claim was requested at all.
Get(claim string) (RequestedClaim, bool)
}
// RequestedClaim describes how a single claim was requested.
type RequestedClaim interface {
// Essential reports whether the claim was requested as essential; Populate only evaluates
// ClaimsFormula entries for essential requests.
Essential() bool
// Values returns the requested values, if any.
Values() []string
// IsDefault — NOTE(review): semantics not visible in this file; confirm before documenting further.
IsDefault() bool
}
// ClaimSpec describes how to compute a claim and whether failing to compute it is fatal.
type ClaimSpec interface {
// Calculate computes the claim value; an error means the value is unavailable.
Calculate(ctx context.Context, opt *FactoryOption) (v interface{}, err error)
// Required reports whether a Calculate error should abort population rather than skip the claim.
Required(ctx context.Context, opt *FactoryOption) bool
}
// claimSpec is the function-based ClaimSpec built by Required, Optional and the other constructors below.
type claimSpec struct {
// Func computes the value; when nil, Calculate fails with errorInvalidSpec.
Func ClaimFactoryFunc
// ReqFunc decides whether the claim is required; when nil, the claim is optional.
ReqFunc ClaimRequirementFunc
}
// Calculate runs the spec's factory function. A spec without one yields errorInvalidSpec.
func (c claimSpec) Calculate(ctx context.Context, opt *FactoryOption) (v interface{}, err error) {
	if fn := c.Func; fn != nil {
		return fn(ctx, opt)
	}
	return nil, errorInvalidSpec
}

// Required reports whether the claim is mandatory. A spec without ReqFunc is treated as optional.
func (c claimSpec) Required(ctx context.Context, opt *FactoryOption) bool {
	return c.ReqFunc != nil && c.ReqFunc(ctx, opt)
}
// Required wraps fn in a spec whose failures always abort population.
func Required(fn ClaimFactoryFunc) ClaimSpec {
	spec := claimSpec{Func: fn, ReqFunc: requiredFunc}
	return &spec
}

// Optional wraps fn in a spec whose failures are silently skipped.
func Optional(fn ClaimFactoryFunc) ClaimSpec {
	spec := claimSpec{Func: fn, ReqFunc: optionalFunc}
	return &spec
}
// RequiredIfParamsExists makes the claim required when the source OAuth2 request carries any
// of requestParams. Only the parameter's presence matters, not its value; otherwise the claim
// is optional.
func RequiredIfParamsExists(fn ClaimFactoryFunc, requestParams ...string) ClaimSpec {
	required := func(_ context.Context, opt *FactoryOption) bool {
		req := opt.Source.OAuth2Request()
		if req == nil {
			return false
		}
		// lookups on a nil parameter map are safe and report "absent"
		params := req.Parameters()
		for _, name := range requestParams {
			if _, present := params[name]; present {
				return true
			}
		}
		return false
	}
	return &claimSpec{Func: fn, ReqFunc: required}
}

// RequiredIfImplicitFlow makes the claim required when the request's response types include
// "token", and optional otherwise.
func RequiredIfImplicitFlow(fn ClaimFactoryFunc) ClaimSpec {
	required := func(_ context.Context, opt *FactoryOption) bool {
		req := opt.Source.OAuth2Request()
		return req != nil && req.ResponseTypes() != nil && req.ResponseTypes().Has("token")
	}
	return &claimSpec{Func: fn, ReqFunc: required}
}
// Unsupported returns an optional spec that never produces a value.
func Unsupported() ClaimSpec {
	never := func(_ context.Context, _ *FactoryOption) (v interface{}, err error) {
		return nil, errorMissingDetails
	}
	return &claimSpec{Func: never, ReqFunc: optionalFunc}
}

// requiredFunc is the ClaimRequirementFunc that always reports "required".
func requiredFunc(_ context.Context, _ *FactoryOption) bool { return true }

// optionalFunc is the ClaimRequirementFunc that always reports "optional".
func optionalFunc(_ context.Context, _ *FactoryOption) bool { return false }
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
/***********************************
Default implementation
***********************************/
// ClientDetails is the plain data backing DefaultOAuth2Client.
type ClientDetails struct {
ClientId string
// Secret, when non-empty, makes the client require a secret (see SecretRequired).
Secret string
GrantTypes utils.StringSet
RedirectUris utils.StringSet
// Scopes are also exposed as the account's permissions.
Scopes utils.StringSet
AutoApproveScopes utils.StringSet
AccessTokenValidity time.Duration
RefreshTokenValidity time.Duration
UseSessionTimeout bool
AssignedTenantIds utils.StringSet
// ResourceIds is exposed via ResourceIDs().
ResourceIds utils.StringSet
}
// DefaultOAuth2Client implements security.Account & OAuth2Client
// by exposing the embedded ClientDetails through both interfaces.
type DefaultOAuth2Client struct {
ClientDetails
}
// NewClient returns a DefaultOAuth2Client with zero-valued ClientDetails.
func NewClient() *DefaultOAuth2Client {
	return new(DefaultOAuth2Client)
}

// NewClientWithDetails returns a DefaultOAuth2Client backed by a copy of clientDetails.
// The copy is shallow: the StringSet fields are not deep-copied.
func NewClientWithDetails(clientDetails ClientDetails) *DefaultOAuth2Client {
	c := new(DefaultOAuth2Client)
	c.ClientDetails = clientDetails
	return c
}
/* OAuth2Client implementation — thin accessors over the embedded ClientDetails. */

// ClientId returns the registered client ID.
func (c *DefaultOAuth2Client) ClientId() string {
	return c.ClientDetails.ClientId
}

// SecretRequired reports whether a non-empty secret is configured.
func (c *DefaultOAuth2Client) SecretRequired() bool {
	return len(c.ClientDetails.Secret) > 0
}

// Secret returns the configured client secret (possibly empty).
func (c *DefaultOAuth2Client) Secret() string {
	return c.ClientDetails.Secret
}

// GrantTypes returns the authorized grant types.
func (c *DefaultOAuth2Client) GrantTypes() utils.StringSet {
	return c.ClientDetails.GrantTypes
}

// RedirectUris returns the registered redirect URIs.
func (c *DefaultOAuth2Client) RedirectUris() utils.StringSet {
	return c.ClientDetails.RedirectUris
}

// Scopes returns the scopes the client may request.
func (c *DefaultOAuth2Client) Scopes() utils.StringSet {
	return c.ClientDetails.Scopes
}

// AutoApproveScopes returns the scopes configured for auto-approval.
func (c *DefaultOAuth2Client) AutoApproveScopes() utils.StringSet {
	return c.ClientDetails.AutoApproveScopes
}

// AccessTokenValidity returns the configured access token validity.
func (c *DefaultOAuth2Client) AccessTokenValidity() time.Duration {
	return c.ClientDetails.AccessTokenValidity
}

// RefreshTokenValidity returns the configured refresh token validity.
func (c *DefaultOAuth2Client) RefreshTokenValidity() time.Duration {
	return c.ClientDetails.RefreshTokenValidity
}

// UseSessionTimeout returns the configured UseSessionTimeout flag.
func (c *DefaultOAuth2Client) UseSessionTimeout() bool {
	return c.ClientDetails.UseSessionTimeout
}

// AssignedTenantIds returns the tenant IDs assigned to the client.
func (c *DefaultOAuth2Client) AssignedTenantIds() utils.StringSet {
	return c.ClientDetails.AssignedTenantIds
}

// ResourceIDs returns the client's resource IDs.
func (c *DefaultOAuth2Client) ResourceIDs() utils.StringSet {
	return c.ClientDetails.ResourceIds
}

// MaxTokensPerUser always returns -1 for this implementation.
func (c *DefaultOAuth2Client) MaxTokensPerUser() int {
	const unlimited = -1
	return unlimited
}
/** security.Account **/

// ID returns the client ID as the account ID.
func (c *DefaultOAuth2Client) ID() interface{} {
	return c.ClientId()
}

// Type always reports security.AccountTypeDefault.
func (c *DefaultOAuth2Client) Type() security.AccountType {
	return security.AccountTypeDefault
}

// Username returns the client ID; clients authenticate with their ID as username.
func (c *DefaultOAuth2Client) Username() string {
	return c.ClientId()
}

// Credentials returns the client secret.
func (c *DefaultOAuth2Client) Credentials() interface{} {
	return c.Secret()
}

// Permissions returns the client's scopes as a slice.
func (c *DefaultOAuth2Client) Permissions() []string {
	return c.Scopes().Values()
}

// Disabled always returns false: clients cannot be disabled through this type.
func (c *DefaultOAuth2Client) Disabled() bool {
	return false
}

// Locked always returns false: clients cannot be locked through this type.
func (c *DefaultOAuth2Client) Locked() bool {
	return false
}

// UseMFA always returns false: clients never go through MFA.
func (c *DefaultOAuth2Client) UseMFA() bool {
	return false
}
func (c *DefaultOAuth2Client) CacheableCopy() security.Account {
copy := DefaultOAuth2Client{
ClientDetails: c.ClientDetails,
}
copy.ClientDetails.Secret = ""
return ©
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
)
/***********************************
default implementation
***********************************/
// OAuth2ClientAccountStore wraps a delegate oauth2.OAuth2ClientStore and implements both security.AccountStore
// and oauth2.OAuth2ClientStore, so that clients can be authenticated like regular accounts.
type OAuth2ClientAccountStore struct {
oauth2.OAuth2ClientStore
}
// WrapOAuth2ClientStore adapts clientStore into an OAuth2ClientAccountStore.
func WrapOAuth2ClientStore(clientStore oauth2.OAuth2ClientStore) *OAuth2ClientAccountStore {
	wrapped := OAuth2ClientAccountStore{OAuth2ClientStore: clientStore}
	return &wrapped
}
// LoadAccountById implements security.AccountStore.
// Client IDs are strings; any other id type results in a "username not found" error.
func (s *OAuth2ClientAccountStore) LoadAccountById(ctx context.Context, id interface{}) (security.Account, error) {
	clientId, ok := id.(string)
	if !ok {
		return nil, security.NewUsernameNotFoundError("invalid clientId type")
	}
	return s.LoadAccountByUsername(ctx, clientId)
}
// LoadAccountByUsername implements security.AccountStore by loading the client whose ID equals username.
// Any lookup failure is reported as "username not found"; a loaded client that does not implement
// security.Account results in an internal authentication error.
func (s *OAuth2ClientAccountStore) LoadAccountByUsername(ctx context.Context, username string) (security.Account, error) {
	client, err := s.OAuth2ClientStore.LoadClientByClientId(ctx, username)
	if err != nil {
		// NOTE(review): the underlying cause is discarded here
		return nil, security.NewUsernameNotFoundError("invalid clientId")
	}
	acct, isAccount := client.(security.Account)
	if !isAccount {
		return nil, security.NewInternalAuthenticationError("loaded client is not an account")
	}
	return acct, nil
}
// LoadLockingRules implements security.AccountStore. Clients have no locking rules, so this always fails.
func (s *OAuth2ClientAccountStore) LoadLockingRules(_ context.Context, _ security.Account) (security.AccountLockingRule, error) {
	return nil, security.NewInternalAuthenticationError("client doesn't have locking rule")
}
// LoadPwdAgingRules implements security.AccountStore. Clients have no password aging rules, so this always fails.
func (s *OAuth2ClientAccountStore) LoadPwdAgingRules(_ context.Context, _ security.Account) (security.AccountPwdAgingRule, error) {
	return nil, security.NewInternalAuthenticationError("client doesn't have aging rule")
}
// Save implements security.AccountStore. Clients cannot be modified during authentication, so this always fails.
func (s *OAuth2ClientAccountStore) Save(_ context.Context, _ security.Account) error {
	return security.NewInternalAuthenticationError("client is inmutable during authentication")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package clientauth
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/basicauth"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
)
var (
// FeatureId identifies the OAuth2 client authentication feature, ordered by security.FeatureOrderOAuth2ClientAuth.
FeatureId = security.FeatureId("OAuth2ClientAuth", security.FeatureOrderOAuth2ClientAuth)
)
// ClientAuthConfigurer applies a ClientAuthFeature to a security.WebSecurity. It is stateless.
//goland:noinspection GoNameStartsWithPackageName
type ClientAuthConfigurer struct {
}
// newClientAuthConfigurer returns a new, stateless ClientAuthConfigurer.
func newClientAuthConfigurer() *ClientAuthConfigurer {
	return new(ClientAuthConfigurer)
}
// Apply configures client authentication on ws.
//
// After validating (and defaulting) the feature, it configures password authentication backed by the client
// store, basic auth without an entry point, "authenticated" access for every request, and the feature's error
// handler. It then adds a middleware that translates authentication errors into OAuth2 invalid-client errors
// and, when form authentication is allowed, a middleware that authenticates clients from form parameters.
func (c *ClientAuthConfigurer) Apply(feature security.Feature, ws security.WebSecurity) (err error) {
	clientAuth := feature.(*ClientAuthFeature)
	if e := c.validate(clientAuth, ws); e != nil {
		return e
	}

	// client secrets are verified like passwords; clients never go through MFA
	passwd.Configure(ws).
		AccountStore(c.clientAccountStore(clientAuth)).
		PasswordEncoder(clientAuth.clientSecretEncoder).
		MFA(false)
	// no entry point, everything handled by access denied handler
	basicauth.Configure(ws).EntryPoint(nil)
	access.Configure(ws).Request(matcher.AnyRequest()).Authenticated()
	errorhandling.Configure(ws).AdditionalErrorHandler(clientAuth.errorHandler)

	clientMW := NewClientAuthMiddleware(func(opt *MWOption) {
		opt.Authenticator = ws.Authenticator()
		opt.SuccessHandler = ws.Shared(security.WSSharedKeyCompositeAuthSuccessHandler).(security.AuthenticationSuccessHandler)
	})

	// translate authentication errors into OAuth2 errors
	errorTranslator := middleware.NewBuilder("client auth error translator").
		Order(security.MWOrderPreAuth).
		Use(clientMW.ErrorTranslationHandlerFunc())
	ws.Add(errorTranslator)

	if !clientAuth.allowForm {
		return nil
	}
	// support form based client auth
	formAuth := middleware.NewBuilder("form client auth").
		Order(security.MWOrderFormAuth).
		Use(clientMW.ClientAuthFormHandlerFunc())
	ws.Add(formAuth)
	return nil
}
// validate requires a client store and fills in defaults for the optional collaborators:
// a no-op secret encoder and the standard OAuth2 error handler. It mutates f in place.
func (c *ClientAuthConfigurer) validate(f *ClientAuthFeature, _ security.WebSecurity) error {
	if f.clientStore == nil {
		return fmt.Errorf("client store for client authentication is not set")
	}
	// defaults for optional settings
	if f.clientSecretEncoder == nil {
		f.clientSecretEncoder = passwd.NewNoopPasswordEncoder()
	}
	if f.errorHandler == nil {
		f.errorHandler = auth.NewOAuth2ErrorHandler()
	}
	return nil
}
// clientAccountStore exposes the feature's client store as a security.AccountStore.
func (c *ClientAuthConfigurer) clientAccountStore(f *ClientAuthFeature) *auth.OAuth2ClientAccountStore {
	store := auth.WrapOAuth2ClientStore(f.clientStore)
	return store
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package clientauth
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
)
// ClientAuthFeature configures OAuth2 client authentication (basic auth, and optionally form parameters).
// Enable it with Configure and customize it with the setters below; a client store is required.
//goland:noinspection GoNameStartsWithPackageName
type ClientAuthFeature struct {
clientStore oauth2.OAuth2ClientStore
clientSecretEncoder passwd.PasswordEncoder
errorHandler *auth.OAuth2ErrorHandler
allowForm bool
}
// Identifier implements security.Feature and returns FeatureId.
func (f *ClientAuthFeature) Identifier() security.FeatureIdentifier {
	return FeatureId
}
// Configure enables client authentication on ws and returns the feature for further configuration.
// It panics if ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *ClientAuthFeature {
	modifier, ok := ws.(security.FeatureModifier)
	if !ok {
		panic(fmt.Errorf("unable to configure oauth2 authserver: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
	}
	return modifier.Enable(New()).(*ClientAuthFeature)
}
// New returns an unconfigured ClientAuthFeature. This is the standard security.Feature entrypoint,
// DSL style, used with security.WebSecurity.
func New() *ClientAuthFeature {
	return new(ClientAuthFeature)
}
/** Setters **/

// ClientStore sets the store clients are loaded from. It is required; validation fails without it.
func (f *ClientAuthFeature) ClientStore(clientStore oauth2.OAuth2ClientStore) *ClientAuthFeature {
	f.clientStore = clientStore
	return f
}

// ClientSecretEncoder sets the encoder used to verify client secrets. A no-op encoder is used when unset.
func (f *ClientAuthFeature) ClientSecretEncoder(clientSecretEncoder passwd.PasswordEncoder) *ClientAuthFeature {
	f.clientSecretEncoder = clientSecretEncoder
	return f
}

// ErrorHandler sets the OAuth2 error handler. auth.NewOAuth2ErrorHandler() is used when unset.
func (f *ClientAuthFeature) ErrorHandler(errorHandler *auth.OAuth2ErrorHandler) *ClientAuthFeature {
	f.errorHandler = errorHandler
	return f
}

// AllowForm toggles client authentication via form parameters (client_id / client_secret).
// Passing true also implicitly enables public clients (clients with an empty secret).
func (f *ClientAuthFeature) AllowForm(allowForm bool) *ClientAuthFeature {
	f.allowForm = allowForm
	return f
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package clientauth
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/gin-gonic/gin"
"net/http"
)
// Middleware provides the client authentication handlers: form-based client authentication and
// translation of authentication errors into OAuth2 errors.
type Middleware struct {
authenticator security.Authenticator
successHandler security.AuthenticationSuccessHandler
}
// MWOptions customizes an MWOption before a Middleware is created (see NewClientAuthMiddleware).
type MWOptions func(*MWOption)
// MWOption holds the collaborators of a Middleware.
type MWOption struct {
// Authenticator authenticates the client_id / client_secret candidate.
Authenticator security.Authenticator
// SuccessHandler is notified after a successful form client authentication.
SuccessHandler security.AuthenticationSuccessHandler
}
// NewClientAuthMiddleware builds a Middleware by applying the given option functions in order.
// nil option functions are ignored.
func NewClientAuthMiddleware(opts ...MWOptions) *Middleware {
	var option MWOption
	for _, apply := range opts {
		if apply == nil {
			continue
		}
		apply(&option)
	}
	return &Middleware{
		authenticator:  option.Authenticator,
		successHandler: option.SuccessHandler,
	}
}
// ClientAuthFormHandlerFunc returns a handler that authenticates an OAuth2 client from the client_id form
// parameter (query string or body) and the client_secret parameter (request body only).
//
// It is meant to run after basic auth:
//   - requests whose form cannot be parsed, or without a client_id, are left untouched;
//   - if basic auth already authenticated the same principal, nothing is done;
//   - if basic auth authenticated a different principal, an invalid client error is recorded and the chain aborted;
//   - otherwise the credentials are authenticated, and on success stored in the security context.
func (mw *Middleware) ClientAuthFormHandlerFunc() http.HandlerFunc {
	return func(rw http.ResponseWriter, r *http.Request) {
		if e := r.ParseForm(); e != nil {
			// best effort: unparsable forms are left to other authentication mechanisms
			return
		}
		if _, hasClientId := r.Form[oauth2.ParameterClientId]; !hasClientId {
			return
		}
		clientId := r.Form.Get(oauth2.ParameterClientId)

		// Form client auth should be placed after basic auth.
		// If already authenticated by basic auth and the principal matches, there is nothing to do here.
		// If authenticated but the principal doesn't match, it's an error.
		before := security.Get(r.Context())
		currentAuth, ok := before.(passwd.UsernamePasswordAuthentication)
		switch {
		case ok && passwd.IsSamePrincipal(clientId, currentAuth):
			return
		case ok:
			mw.handleError(r.Context(), oauth2.NewInvalidClientError("client_id parameter and Authorization header doesn't match"))
			// Must stop here. Continuing would authenticate the form credentials anyway and let
			// handleSuccess overwrite the security context that handleError just cleared.
			return
		}

		// client_secret is read from the request body only (PostForm), never from the query string
		secret := r.PostForm.Get(oauth2.ParameterClientSecret)
		candidate := passwd.UsernamePasswordPair{
			Username:   clientId,
			Password:   secret,
			EnforceMFA: passwd.MFAModeSkip,
		}

		// Authenticate
		auth, err := mw.authenticator.Authenticate(r.Context(), &candidate)
		if err != nil {
			mw.handleError(r.Context(), err)
			return
		}
		mw.handleSuccess(r.Context(), r, rw, before, auth)
	}
}
// ErrorTranslationHandlerFunc returns a gin handler that runs the rest of the chain and then replaces every
// recorded authentication error with an OAuth2 invalid client error that wraps the original.
func (mw *Middleware) ErrorTranslationHandlerFunc() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Next()
		for _, ginErr := range c.Errors {
			if !errors.Is(ginErr.Err, security.ErrorTypeAuthentication) {
				continue
			}
			ginErr.Err = oauth2.NewInvalidClientError("client authentication failed", ginErr.Err)
		}
	}
}
// handleSuccess stores a non-nil authentication in the security context, notifies the success handler,
// and then continues the gin handler chain. A nil authentication only continues the chain.
func (mw *Middleware) handleSuccess(c context.Context, r *http.Request, rw http.ResponseWriter, before, after security.Authentication) {
	ginCtx := web.GinContext(c)
	if after != nil {
		security.MustSet(c, after)
		mw.successHandler.HandleAuthenticationSuccess(c, r, rw, before, after)
	}
	ginCtx.Next()
}
// handleError clears the security context, records err on the gin context and aborts the handler chain.
//
//nolint:contextcheck
func (mw *Middleware) handleError(c context.Context, err error) {
	ginCtx := web.GinContext(c)
	security.MustClear(ginCtx)
	// the returned *gin.Error is not needed; the error is recorded on the context
	_ = ginCtx.Error(err)
	ginCtx.Abort()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package clientauth
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"go.uber.org/fx"
)
// Module is the bootstrap module that registers the client auth feature (see register).
//goland:noinspection GoNameStartsWithPackageName
var Module = &bootstrap.Module{
Name: "oauth2 auth - client auth",
Precedence: security.MinSecurityPrecedence + 20,
Options: []fx.Option{
fx.Invoke(register),
},
}
// init registers Module with the bootstrap framework at package load time.
func init() {
bootstrap.Register(Module)
}
// initDI holds the fx-injected dependencies of register.
type initDI struct {
fx.In
// SecRegistrar is optional; when it is absent, the feature is not registered.
SecRegistrar security.Registrar `optional:"true"`
}
// register registers the client auth feature configurer when a security registrar is available.
// NOTE(review): the registrar is assumed to implement security.FeatureRegistrar; otherwise this panics.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	registrar := di.SecRegistrar.(security.FeatureRegistrar)
	registrar.RegisterFeature(FeatureId, newClientAuthConfigurer())
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
)
/***********************
Common Functions
***********************/
// RetrieveAuthenticatedClient returns the client stored under oauth2.CtxKeyAuthenticatedClient.
// Failing that, it returns the security context's principal when the principal is known and is an
// oauth2.OAuth2Client. It returns nil otherwise.
func RetrieveAuthenticatedClient(c context.Context) oauth2.OAuth2Client {
	if client, ok := c.Value(oauth2.CtxKeyAuthenticatedClient).(oauth2.OAuth2Client); ok {
		return client
	}
	if sec := security.Get(c); sec.State() >= security.StatePrincipalKnown {
		if client, ok := sec.Principal().(oauth2.OAuth2Client); ok {
			return client
		}
	}
	return nil
}
// RetrieveFullyAuthenticatedClient returns the security context's principal as an oauth2.OAuth2Client.
// It returns an invalid grant error if the context is not fully authenticated or the principal is not a client.
func RetrieveFullyAuthenticatedClient(c context.Context) (oauth2.OAuth2Client, error) {
	if sec := security.Get(c); sec.State() >= security.StateAuthenticated {
		if client, ok := sec.Principal().(oauth2.OAuth2Client); ok {
			return client, nil
		}
	}
	return nil, oauth2.NewInvalidGrantError("client is not fully authenticated")
}
// ValidateResponseTypes checks that the authorize request has response types and that all of them are
// in supported. Success is recorded in the request context under ctxKeyValidResponseType, so later
// calls for the same request return nil immediately.
func ValidateResponseTypes(ctx context.Context, request *AuthorizeRequest, supported utils.StringSet) error {
	if request.ResponseTypes == nil {
		return oauth2.NewInvalidAuthorizeRequestError("response_type is required")
	}
	// shortcut: already validated for this request
	if request.Context().Value(ctxKeyValidResponseType) != nil {
		return nil
	}
	allSupported, invalid := IsSubSet(ctx, supported, request.ResponseTypes)
	if !allSupported {
		return oauth2.NewInvalidResponseTypeError(fmt.Sprintf("unsupported response type: %s", invalid))
	}
	request.Context().Set(ctxKeyValidResponseType, true)
	return nil
}
// ValidateGrant requires a non-empty grantType that is among the client's allowed grant types.
func ValidateGrant(_ context.Context, client oauth2.OAuth2Client, grantType string) error {
	switch {
	case grantType == "":
		return oauth2.NewInvalidTokenRequestError("missing grant_type")
	case !client.GrantTypes().Has(grantType):
		return oauth2.NewUnauthorizedClientError(fmt.Sprintf("grant type '%s' is not allowed by this client '%s'", grantType, client.ClientId()))
	default:
		return nil
	}
}
// ValidateScope returns an invalid scope error for the first of scopes the client is not allowed to request.
func ValidateScope(c context.Context, client oauth2.OAuth2Client, scopes ...string) error {
	for _, s := range scopes {
		if client.Scopes().Has(s) {
			continue
		}
		return oauth2.NewInvalidScopeError("invalid scope: " + s)
	}
	return nil
}
// ValidateAllScopes requires every scope in scopes to be one of the client's scopes.
func ValidateAllScopes(c context.Context, client oauth2.OAuth2Client, scopes utils.StringSet) error {
	allowed, invalid := IsSubSet(c, client.Scopes(), scopes)
	if allowed {
		return nil
	}
	return oauth2.NewInvalidScopeError("invalid scope: " + invalid)
}
// ValidateAllAutoApprovalScopes requires every scope in scopes to be auto-approved for the client;
// otherwise it returns an access rejected error.
func ValidateAllAutoApprovalScopes(c context.Context, client oauth2.OAuth2Client, scopes utils.StringSet) error {
	approved, notApproved := IsSubSet(c, client.AutoApproveScopes(), scopes)
	if approved {
		return nil
	}
	return oauth2.NewAccessRejectedError("scope not auto approved: " + notApproved)
}
// IsSubSet reports whether every element of subset is in superset.
// If not, it also returns one offending element; which one is unspecified, because set iteration order is random.
func IsSubSet(_ context.Context, superset utils.StringSet, subset utils.StringSet) (ok bool, invalid string) {
	for elem := range subset {
		if !superset.Has(elem) {
			return false, elem
		}
	}
	return true, ""
}
// ValidateApproval checks the user's approval decision for the requested scopes.
// approval maps scope to approval status. Every scope must be allowed for the client
// (ValidateAllScopes) and explicitly approved; a missing entry counts as not approved.
func ValidateApproval(c context.Context, approval map[string]bool, client oauth2.OAuth2Client, scopes utils.StringSet) error {
	if e := ValidateAllScopes(c, client, scopes); e != nil {
		return e
	}
	for scope := range scopes {
		// a missing key yields false, i.e. not approved
		if !approval[scope] {
			return oauth2.NewAccessRejectedError(fmt.Sprintf("user disapproved scope [%s]", scope))
		}
	}
	return nil
}
// LoadAndValidateClientId loads the client with the given ID from clientStore.
// An empty clientId is an invalid authorize request. Any load failure is reported as "client not found";
// the underlying error is not propagated.
func LoadAndValidateClientId(c context.Context, clientId string, clientStore oauth2.OAuth2ClientStore) (oauth2.OAuth2Client, error) {
	if clientId == "" {
		// plain literal: the previous fmt.Sprintf had no format arguments (staticcheck S1039)
		return nil, oauth2.NewInvalidAuthorizeRequestError("A client id must be provided")
	}
	client, e := clientStore.LoadClientByClientId(c, clientId)
	if e != nil {
		return nil, oauth2.NewClientNotFoundError("invalid client")
	}
	return client, nil
}
// ResolveRedirectUri determines the redirect URI for an authorize request.
//
// The client must have at least one registered redirect URI. If redirectUri is empty, the single registered
// URI is used, and a client with several registered URIs is rejected. A non-empty redirectUri is returned
// as-is if it matches any registered URI treated as a wildcard pattern (registrations that fail to compile
// as a pattern are skipped). Otherwise an invalid redirect URI error is returned.
func ResolveRedirectUri(_ context.Context, redirectUri string, client oauth2.OAuth2Client) (string, error) {
	registered := client.RedirectUris()
	switch {
	case len(registered) == 0: // also covers a nil set
		return "", oauth2.NewInvalidAuthorizeRequestError(
			"at least one redirectUri must be registered in the client")
	case redirectUri == "" && len(registered) == 1:
		// single registered redirect URI
		return registered.Values()[0], nil
	case redirectUri == "":
		return "", oauth2.NewInvalidRedirectUriError("the redirect_uri must be proveded because the client have multiple registered redirect URI")
	}

	for pattern := range registered {
		urlMatcher, e := NewWildcardUrlMatcher(pattern)
		if e != nil {
			continue
		}
		if matches, e := urlMatcher.Matches(redirectUri); e == nil && matches {
			return redirectUri, nil
		}
	}
	return "", oauth2.NewInvalidRedirectUriError("the redirect_uri must be registered with the client")
}
// ConvertOptions controls ConvertToOAuthUserAuthentication.
type ConvertOptions struct {
// SkipTypeCheck forces a conversion even when the input already is an oauth2.UserAuthentication.
SkipTypeCheck bool
// userAuthOptions are applied after the default conversion option, so they can override its fields.
userAuthOptions []OverrideAuthOptions
}
// AppendUserAuthOptions registers an override that runs after the already registered ones.
func (c *ConvertOptions) AppendUserAuthOptions(option OverrideAuthOptions) {
	overrides := append(c.userAuthOptions, option)
	c.userAuthOptions = overrides
}
// ConvertOption customizes ConvertOptions for ConvertToOAuthUserAuthentication.
type ConvertOption func(option *ConvertOptions)
// ConvertWithSkipTypeCheck returns a ConvertOption that sets ConvertOptions.SkipTypeCheck.
func ConvertWithSkipTypeCheck(skipTypeCheck bool) ConvertOption {
	setter := func(opts *ConvertOptions) {
		opts.SkipTypeCheck = skipTypeCheck
	}
	return setter
}
// OverrideAuthOptions allows the oauth2.UserAuthOptions to be overridden during the
// conversion when creating and returning a new user authentication.
// It receives the original authentication and returns an option applied after the default one.
type OverrideAuthOptions func(userAuth security.Authentication) oauth2.UserAuthOptions
// ConvertToOAuthUserAuthentication takes any type of authentication and converts it into oauth2.UserAuthentication.
//
// A nil userAuth always yields nil. Unless ConvertWithSkipTypeCheck(true) is given, a userAuth that already
// is an oauth2.UserAuthentication is returned unchanged. Otherwise a new user authentication is built:
//   - the principal is the username, falling back to fmt's %v of userAuth;
//   - permissions and state are copied;
//   - details are used directly if they are a map[string]interface{}, otherwise wrapped under "Literal";
//   - registered OverrideAuthOptions are applied afterwards, in order.
//
// nil option functions are ignored.
func ConvertToOAuthUserAuthentication(userAuth security.Authentication, options ...ConvertOption) oauth2.UserAuthentication {
	var opts ConvertOptions
	for _, opt := range options {
		if opt != nil {
			opt(&opts)
		}
	}

	// checked regardless of SkipTypeCheck: a nil input would otherwise panic on userAuth.Details() below
	if userAuth == nil {
		return nil
	}
	if !opts.SkipTypeCheck {
		if ua, ok := userAuth.(oauth2.UserAuthentication); ok {
			return ua
		}
	}

	principal, e := security.GetUsername(userAuth)
	if e != nil {
		principal = fmt.Sprintf("%v", userAuth)
	}
	details, ok := userAuth.Details().(map[string]interface{})
	if !ok {
		details = map[string]interface{}{
			"Literal": userAuth.Details(),
		}
	}

	authOptions := make([]oauth2.UserAuthOptions, 0, len(opts.userAuthOptions)+1)
	// default conversion first, so overrides can replace any of its fields
	authOptions = append(authOptions, func(opt *oauth2.UserAuthOption) {
		opt.Principal = principal
		opt.Permissions = userAuth.Permissions()
		opt.State = userAuth.State()
		opt.Details = details
	})
	for _, override := range opts.userAuthOptions {
		if override == nil {
			continue
		}
		if o := override(userAuth); o != nil {
			authOptions = append(authOptions, o)
		}
	}
	return oauth2.NewUserAuthentication(authOptions...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"net/http"
)
var (
// errorInvalidRedirectUri is a reference error used with errors.Is to detect invalid redirect URI errors.
// Those errors are never answered by redirecting (see handleError / writeOAuth2Error).
errorInvalidRedirectUri = oauth2.NewInvalidRedirectUriError("")
)
// OAuth2ErrorHandler implements security.ErrorHandler
// It's responsible for handling all oauth2 errors; errors that are not
// oauth2.OAuth2ErrorTranslator are left to other error handlers.
type OAuth2ErrorHandler struct {}
// NewOAuth2ErrorHandler returns a stateless OAuth2ErrorHandler.
func NewOAuth2ErrorHandler() *OAuth2ErrorHandler {
	return new(OAuth2ErrorHandler)
}
// HandleError implements security.ErrorHandler by delegating to handleError.
// Non-OAuth2 errors are ignored so other error handlers get a chance to handle them.
func (h *OAuth2ErrorHandler) HandleError(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	h.handleError(c, r, rw, err)
}
// handleError writes an OAuth2 error response for errors implementing oauth2.OAuth2ErrorTranslator.
// Invalid redirect URI errors are written directly. Other OAuth2 errors, including internal ones, are sent
// to the client's redirect URI when possible, and otherwise written directly.
// Anything else is ignored so other error handlers can handle it.
func (h *OAuth2ErrorHandler) handleError(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	oe, ok := err.(oauth2.OAuth2ErrorTranslator) //nolint:errorlint // errors.Is checks below cover wrapped errors
	if !ok {
		return
	}
	switch {
	case errors.Is(err, errorInvalidRedirectUri):
		writeOAuth2Error(c, r, rw, oe)
	case errors.Is(err, oauth2.ErrorSubTypeOAuth2Internal), errors.Is(err, oauth2.ErrorTypeOAuth2):
		// use redirect uri, fallback to standard error response
		tryWriteErrorAsRedirect(c, r, rw, oe)
	}
	// No default: give other error handlers a chance to handle the error
}
// writeOAuth2Error writes err directly using its translated status code.
// 401/403 responses get a "Bearer <message>" WWW-Authenticate challenge. Invalid redirect URI errors are
// written with security.WriteError; all others as JSON.
func writeOAuth2Error(c context.Context, r *http.Request, rw http.ResponseWriter, err oauth2.OAuth2ErrorTranslator) {
	status := err.TranslateStatusCode()
	var challenge string
	switch status {
	case http.StatusUnauthorized, http.StatusForbidden:
		challenge = "Bearer " + err.Error()
	}
	writeAdditionalHeader(c, r, rw, challenge)

	if errors.Is(err, errorInvalidRedirectUri) {
		security.WriteError(c, r, rw, status, err)
		return
	}
	security.WriteErrorAsJson(c, rw, status, err)
}
// writeAdditionalHeader sets the no-cache headers required for OAuth2 error responses and, when challenge is
// non-empty, the WWW-Authenticate header. It does nothing if the response was already written.
func writeAdditionalHeader(_ context.Context, _ *http.Request, rw http.ResponseWriter, challenge string) {
	if security.IsResponseWritten(rw) {
		return
	}
	header := rw.Header()
	// Set rather than Add: RFC 6749 requires "Cache-Control: no-store". Add would append a second,
	// possibly conflicting value if one was set earlier in the chain.
	header.Set("Cache-Control", "no-store")
	header.Set("Pragma", "no-cache")
	if challenge != "" {
		header.Set("WWW-Authenticate", challenge)
	}
}
// tryWriteErrorAsRedirect redirects (302) to the authorize request's redirect URI with "error" and
// "error_description" query parameters. If no redirect URL can be composed, it falls back to writeOAuth2Error.
// It does nothing if the response was already written.
func tryWriteErrorAsRedirect(c context.Context, r *http.Request, rw http.ResponseWriter, err oauth2.OAuth2ErrorTranslator) {
	if security.IsResponseWritten(rw) {
		return
	}
	params := map[string]string{
		oauth2.ParameterError:            err.TranslateErrorCode(),
		oauth2.ParameterErrorDescription: err.Error(),
	}

	// TODO support fragment
	redirectUrl, e := composeRedirectUrl(c, findAuthorizeRequest(c, r), params, false)
	if e == nil {
		http.Redirect(rw, r, redirectUrl, http.StatusFound)
		_, _ = rw.Write([]byte{})
		return
	}
	// fallback to default
	writeOAuth2Error(c, r, rw, err)
}
// findAuthorizeRequest returns the authorize request from the context. A validated request is preferred
// over a merely received one. It returns nil if neither is present.
func findAuthorizeRequest(c context.Context, _ *http.Request) *AuthorizeRequest {
	candidates := []interface{}{
		oauth2.CtxKeyValidatedAuthorizeRequest,
		oauth2.CtxKeyReceivedAuthorizeRequest,
	}
	for _, key := range candidates {
		if ar, ok := c.Value(key).(*AuthorizeRequest); ok {
			return ar
		}
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
)
// TokenGranter issues access tokens for the grant types it supports.
type TokenGranter interface {
// Grant creates an oauth2.AccessToken based on the given TokenRequest
// returns
// - (nil, nil) if the TokenGranter doesn't support given request
// - (non-nil, nil) if the TokenGranter support given request and created a token without error
// - (nil, non-nil) if the TokenGranter support given request but rejected the request
Grant(ctx context.Context, request *TokenRequest) (oauth2.AccessToken, error)
}
// AuthorizationServiceInjector
// By implementing this interface, a component can ask the framework to call its Inject method
// to get a reference to the AuthorizationService.
// Currently, only components that also implement the TokenGranter interface have their Inject method called.
type AuthorizationServiceInjector interface {
Inject(authService AuthorizationService)
}
// CompositeTokenGranter implements TokenGranter
// by consulting its delegates in order (see Grant).
type CompositeTokenGranter struct {
delegates []TokenGranter
}
// NewCompositeTokenGranter returns a CompositeTokenGranter that tries delegates in the given order.
// The variadic slice is kept as-is (not copied).
func NewCompositeTokenGranter(delegates ...TokenGranter) *CompositeTokenGranter {
	composite := new(CompositeTokenGranter)
	composite.delegates = delegates
	return composite
}
// Grant asks each delegate in order. The first delegate that returns an error or a token decides the result.
// If no delegate supports the request, a "granter not available" error is returned.
func (g *CompositeTokenGranter) Grant(ctx context.Context, request *TokenRequest) (oauth2.AccessToken, error) {
	for _, delegate := range g.delegates {
		token, e := delegate.Grant(ctx, request)
		switch {
		case e != nil:
			return nil, e
		case token != nil:
			return token, nil
		}
	}
	return nil, oauth2.NewGranterNotAvailableError(fmt.Sprintf("grant type [%s] is not supported", request.GrantType))
}
// Add appends granter as the last delegate and returns g for chaining.
func (g *CompositeTokenGranter) Add(granter TokenGranter) *CompositeTokenGranter {
	updated := append(g.delegates, granter)
	g.delegates = updated
	return g
}

// Delegates returns the internal delegate slice (not a copy).
func (g *CompositeTokenGranter) Delegates() []TokenGranter {
	delegates := g.delegates
	return delegates
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package grants
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/utils"
)
var (
// authCodeIgnoreParams lists token request parameters that mergedOAuth2Request does not copy onto the
// stored authorize request when an authorization code is exchanged (scope, client secret and PKCE values).
authCodeIgnoreParams = utils.NewStringSet(
oauth2.ParameterScope,
oauth2.ParameterClientSecret,
oauth2.ParameterCodeVerifier,
oauth2.ParameterCodeChallenge,
oauth2.ParameterCodeChallengeMethod,
)
)
// AuthorizationCodeGranter implements auth.TokenGranter
// for the authorization code grant: it redeems codes from authCodeStore and issues tokens via authService.
type AuthorizationCodeGranter struct {
authService auth.AuthorizationService
authCodeStore auth.AuthorizationCodeStore
}
// NewAuthorizationCodeGranter creates an AuthorizationCodeGranter. It panics if either dependency is nil.
func NewAuthorizationCodeGranter(authService auth.AuthorizationService, authCodeStore auth.AuthorizationCodeStore) *AuthorizationCodeGranter {
	switch {
	case authService == nil:
		panic(fmt.Errorf("cannot create AuthorizationCodeGranter without auth service"))
	case authCodeStore == nil:
		panic(fmt.Errorf("cannot create AuthorizationCodeGranter without auth code service"))
	}
	return &AuthorizationCodeGranter{authService: authService, authCodeStore: authCodeStore}
}
// Grant implements auth.TokenGranter for oauth2.GrantTypeAuthCode.
//
// It returns (nil, nil) for any other grant type. Otherwise, it:
//   - requires an authenticated client that is allowed to use this grant;
//   - consumes the authorization code from the request extensions;
//   - checks PKCE, redirect_uri and client ID against the stored authorize request;
//   - creates an authentication and access token from the stored request and user authentication.
func (g *AuthorizationCodeGranter) Grant(ctx context.Context, request *auth.TokenRequest) (oauth2.AccessToken, error) {
	if oauth2.GrantTypeAuthCode != request.GrantType {
		return nil, nil
	}

	client := auth.RetrieveAuthenticatedClient(ctx)
	if client == nil {
		// without this guard, ValidateGrant would dereference a nil client
		return nil, oauth2.NewInvalidClientError("client is not authenticated")
	}

	// common check
	if e := auth.ValidateGrant(ctx, client, request.GrantType); e != nil {
		return nil, e
	}

	// load authentication using auth code
	code, ok := request.Extensions[oauth2.ParameterAuthCode].(string)
	if !ok || code == "" {
		return nil, oauth2.NewInvalidTokenRequestError(fmt.Sprintf("missing required parameter %s", oauth2.ParameterAuthCode))
	}
	stored, e := g.authCodeStore.ConsumeAuthorizationCode(ctx, code, true)
	if e != nil {
		return nil, e
	} else if !stored.OAuth2Request().Approved() || stored.UserAuthentication() == nil {
		return nil, oauth2.NewInvalidGrantError("original authorize request is invalid")
	}

	// PKCE
	if e := validatePKCE(stored.OAuth2Request(), request); e != nil {
		return nil, e
	}

	// check redirect uri
	if e := validateRedirectUri(stored.OAuth2Request(), request); e != nil {
		return nil, e
	}

	// check client ID
	if stored.OAuth2Request().ClientId() != client.ClientId() {
		return nil, oauth2.NewInvalidGrantError("client ID mismatch")
	}

	// create authentication from stored value
	oauthRequest, e := mergedOAuth2Request(stored.OAuth2Request(), request, authCodeIgnoreParams)
	if e != nil {
		return nil, e
	}
	oauth, e := g.authService.CreateAuthentication(ctx, oauthRequest, stored.UserAuthentication())
	if e != nil {
		return nil, oauth2.NewInvalidGrantError(e)
	}

	// create token
	token, e := g.authService.CreateAccessToken(ctx, oauth)
	if e != nil {
		return nil, oauth2.NewInvalidGrantError(e)
	}
	return token, nil
}
// validatePKCE enforces Proof Key for Code Exchange (https://datatracker.ietf.org/doc/html/rfc7636).
//
// If neither the stored authorize request has a code_challenge nor the token request has a code_verifier,
// PKCE is not in use and nil is returned. Otherwise both must be non-empty, the stored challenge method must
// be supported, and the verifier must match the challenge. Each failure is an invalid token request error.
func validatePKCE(stored oauth2.OAuth2Request, request *auth.TokenRequest) error {
	challenge, hasChallenge := stored.Parameters()[oauth2.ParameterCodeChallenge]
	verifier, hasVerifier := request.Parameters[oauth2.ParameterCodeVerifier]
	if !hasChallenge && !hasVerifier {
		return nil
	}

	switch {
	case challenge == "":
		return oauth2.NewInvalidTokenRequestError(fmt.Errorf(`unexpected "%s"`, oauth2.ParameterCodeVerifier))
	case verifier == "":
		return oauth2.NewInvalidTokenRequestError(fmt.Errorf(`"%s" required`, oauth2.ParameterCodeVerifier))
	}

	methodStr := stored.Parameters()[oauth2.ParameterCodeChallengeMethod]
	method, e := parseCodeChallengeMethod(methodStr)
	if e != nil {
		// the quoted value is the method exactly as stored (a stray ':' used to be included inside the quotes)
		return oauth2.NewInvalidTokenRequestError(fmt.Errorf(`unsupported code challenge method "%s"`, methodStr))
	}

	if !verifyPKCE(verifier, challenge, method) {
		return oauth2.NewInvalidTokenRequestError(fmt.Errorf(`invalid "%s"`, oauth2.ParameterCodeVerifier))
	}
	return nil
}
// validateRedirectUri enforces https://tools.ietf.org/html/rfc6749#section-4.1.3: when the original
// authorize request carried an explicit, non-empty redirect_uri (rather than relying on one implied
// from the client registration), the token request must repeat exactly the same value.
//
// A missing redirect_uri yields an invalid-token-request error; a different value yields an
// invalid-grant error.
func validateRedirectUri(stored oauth2.OAuth2Request, request *auth.TokenRequest) error {
	// an absent key reads as "", so one emptiness test covers both "not provided" and "empty"
	origRedirect := stored.Parameters()[oauth2.ParameterRedirectUri]
	if origRedirect == "" {
		return nil
	}
	reqRedirect, present := request.Parameters[oauth2.ParameterRedirectUri]
	switch {
	case !present:
		return oauth2.NewInvalidTokenRequestError("redirect_uri is required because redirect URL was provided when obtaining the auth code")
	case reqRedirect != origRedirect:
		return oauth2.NewInvalidGrantError("redirect_uri doesn't match the original redirect URL used when obtaining the auth code")
	default:
		return nil
	}
}
// mergedOAuth2Request derives a new OAuth2 request from src. The derived request takes the token
// request's grant type, and the token request's parameters and extensions are copied over it,
// skipping any key contained in ignoreParams. The returned error is always nil.
func mergedOAuth2Request(src oauth2.OAuth2Request, request *auth.TokenRequest, ignoreParams utils.StringSet) (oauth2.OAuth2Request, error) {
	customize := func(opt *oauth2.RequestDetails) {
		opt.GrantType = request.GrantType
		for key, value := range request.Parameters {
			if !ignoreParams.Has(key) {
				opt.Parameters[key] = value
			}
		}
		for key, value := range request.Extensions {
			if !ignoreParams.Has(key) {
				opt.Extensions[key] = value
			}
		}
	}
	return src.NewOAuth2Request(customize), nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package grants
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
)
// ClientCredentialsGranter implements auth.TokenGranter for the "client_credentials" grant type.
type ClientCredentialsGranter struct {
	authService auth.AuthorizationService
}

// NewClientCredentialsGranter creates a ClientCredentialsGranter that creates authentications and
// access tokens through authService. It panics when authService is nil.
func NewClientCredentialsGranter(authService auth.AuthorizationService) *ClientCredentialsGranter {
	if authService != nil {
		return &ClientCredentialsGranter{authService: authService}
	}
	panic(fmt.Errorf("cannot create ClientCredentialsGranter without token service."))
}
// Grant issues an access token for the client_credentials grant. For any other grant type it
// returns (nil, nil).
//
// The client must be fully authenticated (client secret validated). When the request names no
// scopes, the client's registered scopes are used, and every scope must be auto-approved for the
// client. The authentication has no user part (nil user authentication).
func (g *ClientCredentialsGranter) Grant(ctx context.Context, request *auth.TokenRequest) (oauth2.AccessToken, error) {
	if request.GrantType != oauth2.GrantTypeClientCredentials {
		return nil, nil
	}

	client, err := auth.RetrieveFullyAuthenticatedClient(ctx)
	if err != nil {
		return nil, oauth2.NewInvalidGrantError("client_credentials requires client secret validated")
	}
	if err := CommonPreGrantValidation(ctx, client, request); err != nil {
		return nil, err
	}

	// no requested scopes: fall back to everything registered for the client
	if len(request.Scopes) == 0 {
		request.Scopes = client.Scopes()
	}
	if err := auth.ValidateAllAutoApprovalScopes(ctx, client, request.Scopes); err != nil {
		return nil, err
	}

	oauth, err := g.authService.CreateAuthentication(ctx, request.OAuth2Request(client), nil)
	if err != nil {
		return nil, oauth2.NewInvalidGrantError(err)
	}
	token, err := g.authService.CreateAccessToken(ctx, oauth)
	if err != nil {
		return nil, oauth2.NewInvalidGrantError(err)
	}
	return token, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package grants
import (
"context"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
)
// logger is this package's logger, created with the name "OAuth2.Grant".
// NOTE(review): not referenced in the code visible here; presumably used by other files of package grants.
var logger = log.New("OAuth2.Grant")
// CommonPreGrantValidation runs the checks shared by several granters. The client must be allowed to
// use the request's grant type, and every requested scope must be valid for the client. It returns
// the error of the first failing check, or nil.
func CommonPreGrantValidation(c context.Context, client oauth2.OAuth2Client, request *auth.TokenRequest) error {
	err := auth.ValidateGrant(c, client, request.GrantType)
	if err == nil {
		err = auth.ValidateAllScopes(c, client, request.Scopes)
	}
	return err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package grants
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
)
// PasswordGranter implements auth.TokenGranter for the "password" grant type. It authenticates the
// supplied username/password with authenticator and issues tokens through authService.
type PasswordGranter struct {
	authenticator security.Authenticator
	authService   auth.AuthorizationService
}

// NewPasswordGranter creates a PasswordGranter. It panics when authenticator or authService is nil
// (authenticator is checked first).
func NewPasswordGranter(authService auth.AuthorizationService, authenticator security.Authenticator) *PasswordGranter {
	switch {
	case authenticator == nil:
		panic(fmt.Errorf("cannot create PasswordGranter without authenticator."))
	case authService == nil:
		panic(fmt.Errorf("cannot create PasswordGranter without authorization service."))
	}
	return &PasswordGranter{
		authService:   authService,
		authenticator: authenticator,
	}
}
// Grant issues an access token for the resource owner "password" grant. For any other grant type
// it returns (nil, nil).
//
// The username and password are read from request.Parameters. The password entry is deleted from
// request.Parameters immediately, presumably so it is not carried into the OAuth2 request built from
// request later on. The credentials are then authenticated. When no scopes were requested, the
// client's scopes are used, and every scope must be auto-approved for the client.
func (g *PasswordGranter) Grant(ctx context.Context, request *auth.TokenRequest) (oauth2.AccessToken, error) {
	if oauth2.GrantTypePassword != request.GrantType {
		return nil, nil
	}
	client := auth.RetrieveAuthenticatedClient(ctx)
	// common check
	if e := CommonPreGrantValidation(ctx, client, request); e != nil {
		return nil, e
	}

	// extract username & password, removing the password from the request right away
	username, uOk := request.Parameters[oauth2.ParameterUsername]
	password, pOk := request.Parameters[oauth2.ParameterPassword]
	delete(request.Parameters, oauth2.ParameterPassword)
	if !uOk || !pOk {
		return nil, oauth2.NewInvalidGrantError("missing 'username' and 'password'")
	}

	// authenticate
	candidate := passwd.UsernamePasswordPair{
		Username: username,
		Password: password,
	}
	userAuth, err := g.authenticator.Authenticate(ctx, &candidate)
	switch {
	case err != nil:
		return nil, oauth2.NewInvalidGrantError(err)
	case userAuth == nil || userAuth.State() < security.StateAuthenticated:
		// a nil result must not be dereferenced, and a nil err is not a meaningful error value
		return nil, oauth2.NewInvalidGrantError("user authentication failed")
	}

	// additional check: default to the client's scopes, all of which must be auto-approved
	if len(request.Scopes) == 0 {
		request.Scopes = client.Scopes()
	}
	if e := auth.ValidateAllAutoApprovalScopes(ctx, client, request.Scopes); e != nil {
		return nil, e
	}

	// create authentication
	req := request.OAuth2Request(client)
	oauth, e := g.authService.CreateAuthentication(ctx, req, userAuth)
	if e != nil {
		return nil, oauth2.NewInvalidGrantError(e)
	}
	// create token
	token, e := g.authService.CreateAccessToken(ctx, oauth)
	if e != nil {
		return nil, oauth2.NewInvalidGrantError(e)
	}
	return token, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package grants
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/tokenauth"
"github.com/cisco-open/go-lanai/pkg/utils"
)
var (
// permissionBasedIgnoreParams lists token-request parameter/extension keys that
// PermissionBasedGranter.reduceScope does not copy into the derived OAuth2 request
// (presumably to keep the client secret and the presented access token out of it).
permissionBasedIgnoreParams = utils.NewStringSet(
oauth2.ParameterClientSecret,
oauth2.ParameterAccessToken,
)
)
// PermissionBasedGranter is a helper base struct that provides common implementations for
// permission-based granters; it is embedded by SwitchTenantGranter and SwitchUserGranter.
type PermissionBasedGranter struct {
// authenticator authenticates the access token supplied with the token request (see authenticateToken).
authenticator security.Authenticator
}
// authenticateToken authenticates the bearer token found in the request's access_token extension
// and returns the resulting oauth2.Authentication.
//
// It fails with an invalid-token-request error when the extension is missing or not a string. It
// fails with an invalid-grant error when authentication fails, when the result is not an
// oauth2.Authentication or is not authenticated, or when it carries no authenticated user.
func (g *PermissionBasedGranter) authenticateToken(ctx context.Context, request *auth.TokenRequest) (oauth2.Authentication, error) {
	tokenValue, isString := request.Extensions[oauth2.ParameterAccessToken].(string)
	if !isString {
		return nil, oauth2.NewInvalidTokenRequestError("access_token is missing")
	}

	// named "authenticated" rather than "auth" so the auth package is not shadowed
	authenticated, e := g.authenticator.Authenticate(ctx, &tokenauth.BearerToken{
		Token:      tokenValue,
		DetailsMap: map[string]interface{}{},
	})
	if e != nil {
		return nil, oauth2.NewInvalidGrantError(e)
	}

	oauth, isOAuth := authenticated.(oauth2.Authentication)
	// short-circuit keeps State() from being called on a failed type assertion
	if !isOAuth || oauth.State() < security.StateAuthenticated {
		return nil, oauth2.NewInvalidGrantError("invalid access token", e)
	}
	if oauth.UserAuthentication() == nil || oauth.UserAuthentication().State() < security.StateAuthenticated {
		return nil, oauth2.NewInvalidGrantError("access token is not associated with a valid user")
	}
	return oauth, nil
}
// validateStoredPermissions ensures stored holds every permission listed in permissions.
// A nil permission map, or the first missing permission, yields an invalid-grant error.
// ctx is currently unused.
func (g *PermissionBasedGranter) validateStoredPermissions(ctx context.Context, stored security.Authentication, permissions ...string) error {
	granted := stored.Permissions()
	if granted == nil {
		return oauth2.NewInvalidGrantError("user has no permissions")
	}
	for _, required := range permissions {
		if _, has := granted[required]; has {
			continue
		}
		return oauth2.NewInvalidGrantError(fmt.Sprintf("user doesn't have required permission [%s]", required))
	}
	return nil
}
// validateStoredClient rejects a security context switch requested by a client other than the one
// the original request was made by: src.ClientId() must equal client.ClientId(). ctx is unused.
//
// Expectation is that only users with appropriate VIEW_OPERATOR_LOGIN_AS_CUSTOMER and
// SWITCH_TENANT permissions along with appropriate grant type are allowed to perform the security
// context switch. This enforcement is done in other parts of the security context switch flow.
func (g *PermissionBasedGranter) validateStoredClient(ctx context.Context, client oauth2.OAuth2Client, src oauth2.OAuth2Request) error {
	original := src.ClientId()
	requested := client.ClientId()
	if original != requested {
		return oauth2.NewInvalidGrantError(fmt.Sprintf("security context switch is not allowed: original Client ID [%s] and requesting Client ID [%s] differ", original, requested))
	}
	return nil
}
// reduceScope derives the OAuth2 request for the switched security context from src.
//
// The scopes come from the token request, or from src when the token request names none. This
// method does not assume the requesting client is the original one. For a different client the
// scopes would ideally be re-authorized by the user; since auto-approval is always used, it is
// enough to require auto-approved scopes:
//   - different client: every scope must be valid and auto-approved for the current client;
//   - same client: every scope must be either one of src's scopes or auto-approved for the client.
//
// The derived request:
//   - is bound to the current client and has its redirect URI cleared;
//   - takes the token request's grant type and the chosen scopes;
//   - receives the token request's parameters and extensions, except keys in permissionBasedIgnoreParams.
func (g *PermissionBasedGranter) reduceScope(ctx context.Context, client oauth2.OAuth2Client,
	src oauth2.OAuth2Request, request *auth.TokenRequest) (oauth2.OAuth2Request, error) {
	original := src.Scopes()
	scopes := request.Scopes
	if len(scopes) == 0 {
		scopes = original
	}

	if client.ClientId() == src.ClientId() {
		for scope := range scopes {
			if original.Has(scope) || client.AutoApproveScopes().Has(scope) {
				continue
			}
			return nil, oauth2.NewInvalidScopeError(fmt.Sprintf("scope [%s] is not allowed by this client", scope))
		}
	} else {
		if e := auth.ValidateAllScopes(ctx, client, scopes); e != nil {
			return nil, e
		}
		if e := auth.ValidateAllAutoApprovalScopes(ctx, client, scopes); e != nil {
			return nil, e
		}
	}

	return src.NewOAuth2Request(func(opt *oauth2.RequestDetails) {
		opt.ClientId = client.ClientId()
		opt.RedirectUri = ""
		opt.GrantType = request.GrantType
		opt.Scopes = scopes
		for key, value := range request.Parameters {
			if !permissionBasedIgnoreParams.Has(key) {
				opt.Parameters[key] = value
			}
		}
		for key, value := range request.Extensions {
			if !permissionBasedIgnoreParams.Has(key) {
				opt.Extensions[key] = value
			}
		}
	}), nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package grants
import (
	"crypto"
	// crypto.Hash.New panics unless the hash implementation is linked into the binary; link SHA-256 explicitly.
	_ "crypto/sha256"
	"encoding/base64"
	"fmt"
	"strings"
)
const (
	// PKCEChallengeMethodPlain is the "plain" method: the challenge equals the verifier.
	PKCEChallengeMethodPlain PKCECodeChallengeMethod = "plain"
	// PKCEChallengeMethodSHA256 is the "S256" method: the challenge is the unpadded base64url
	// encoding of the SHA-256 digest of the verifier.
	PKCEChallengeMethodSHA256 PKCECodeChallengeMethod = "S256"
)

// PKCECodeChallengeMethod is a PKCE code_challenge_method value (RFC 7636 section 4.3).
type PKCECodeChallengeMethod string

// UnmarshalText implements encoding.TextUnmarshaler.
//
// "plain" is matched after strings.ToLower and "S256" after strings.ToUpper, so both are
// case-insensitive. Empty input defaults to plain, matching RFC 7636's default when the method is
// omitted. Any other input returns an error and leaves *m unchanged.
func (m *PKCECodeChallengeMethod) UnmarshalText(text []byte) error {
	str := string(text)
	switch {
	case len(text) == 0, string(PKCEChallengeMethodPlain) == strings.ToLower(str):
		*m = PKCEChallengeMethodPlain
	case string(PKCEChallengeMethodSHA256) == strings.ToUpper(str):
		*m = PKCEChallengeMethodSHA256
	default:
		return fmt.Errorf("invalid code challenge method")
	}
	return nil
}

// parseCodeChallengeMethod converts str using the UnmarshalText rules. On error the returned
// method is the zero value.
func parseCodeChallengeMethod(str string) (PKCECodeChallengeMethod, error) {
	var ret PKCECodeChallengeMethod
	err := ret.UnmarshalText([]byte(str))
	return ret, err
}

// verifyPKCE reports whether toVerify (the code_verifier) matches challenge under method, per
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.6. Unknown methods never match.
//
// crypto.Hash.New panics when the hash implementation is not linked into the binary, so SHA-256
// availability is checked first. Verification then fails closed instead of panicking mid-request.
func verifyPKCE(toVerify string, challenge string, method PKCECodeChallengeMethod) bool {
	var encoded string
	switch method {
	case PKCEChallengeMethodPlain:
		encoded = toVerify
	case PKCEChallengeMethodSHA256:
		if !crypto.SHA256.Available() {
			return false
		}
		hash := crypto.SHA256.New()
		if _, e := hash.Write([]byte(toVerify)); e != nil {
			return false
		}
		encoded = base64.RawURLEncoding.EncodeToString(hash.Sum(nil))
	default:
		return false
	}
	return encoded == challenge
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package grants
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/utils"
)
var (
// refreshIgnoreParams lists token-request parameter/extension keys that reduceScope does not copy
// into the derived OAuth2 request (presumably to keep the client secret and refresh token out of it).
refreshIgnoreParams = utils.NewStringSet(
oauth2.ParameterClientSecret,
oauth2.ParameterRefreshToken,
)
)
// RefreshGranter implements auth.TokenGranter for the refresh_token grant. It reads the refresh token
// and its stored authentication from tokenStore and issues a new access token through authService.
type RefreshGranter struct {
	authService auth.AuthorizationService
	tokenStore  auth.TokenStore
}

// NewRefreshGranter creates a RefreshGranter. It panics when authService is nil.
// NOTE(review): tokenStore is not nil-checked, yet Grant calls methods on it unconditionally;
// callers must supply a non-nil store.
func NewRefreshGranter(authService auth.AuthorizationService, tokenStore auth.TokenStore) *RefreshGranter {
	if authService == nil {
		panic(fmt.Errorf("cannot create RefreshGranter without auth service."))
	}
	return &RefreshGranter{
		authService: authService,
		tokenStore:  tokenStore,
	}
}
// Grant issues a new access token for the refresh_token grant. For any other grant type it returns
// (nil, nil).
//
// Steps:
//   - The refresh token is taken from the request's refresh_token extension.
//   - An expired token is removed from the store (best effort) and rejected.
//   - The stored authentication must belong to the requesting client.
//   - The requested scopes may only narrow the originally approved ones (see reduceScope).
func (g *RefreshGranter) Grant(ctx context.Context, request *auth.TokenRequest) (oauth2.AccessToken, error) {
	if oauth2.GrantTypeRefresh != request.GrantType {
		return nil, nil
	}
	// the refresh grant requires an authenticated client
	client := auth.RetrieveAuthenticatedClient(ctx)
	if client == nil {
		return nil, oauth2.NewInvalidGrantError("refresh_token grant requires an authenticated client")
	}
	// common check
	if e := CommonPreGrantValidation(ctx, client, request); e != nil {
		return nil, e
	}

	// extract refresh token
	refresh, ok := request.Extensions[oauth2.ParameterRefreshToken].(string)
	if !ok || refresh == "" {
		return nil, oauth2.NewInvalidTokenRequestError(fmt.Sprintf("missing required parameter %s", oauth2.ParameterRefreshToken))
	}
	refreshToken, e := g.tokenStore.ReadRefreshToken(ctx, refresh)
	if e != nil {
		return nil, oauth2.NewInvalidGrantError(e)
	}
	if refreshToken.WillExpire() && refreshToken.Expired() {
		// best-effort cleanup; a removal failure must not replace the "expired" error
		_ = g.tokenStore.RemoveRefreshToken(ctx, refreshToken)
		return nil, oauth2.NewInvalidGrantError("refresh token expired")
	}

	// load and validate stored authentication: it must belong to the requesting client
	stored, e := g.tokenStore.ReadAuthentication(ctx, refresh, oauth2.TokenHintRefreshToken)
	if e != nil {
		return nil, oauth2.NewInvalidGrantError(e)
	}
	if stored.OAuth2Request().ClientId() != client.ClientId() {
		return nil, oauth2.NewInvalidGrantError("client ID mismatch")
	}

	// reduced scope
	oauthRequest, e := reduceScope(ctx, client, stored.OAuth2Request(), request)
	if e != nil {
		return nil, e
	}

	// construct auth
	// Note: user's authentication/details should be reloaded and re-verified in this process.
	oauth, e := g.authService.CreateAuthentication(ctx, oauthRequest, stored.UserAuthentication())
	if e != nil {
		return nil, oauth2.NewInvalidGrantError(e)
	}
	// create token
	token, e := g.authService.RefreshAccessToken(ctx, oauth, refreshToken)
	if e != nil {
		return nil, oauth2.NewInvalidGrantError(e)
	}
	return token, nil
}
// reduceScope derives the OAuth2 request for a refreshed token from src, the request stored with
// the refresh token. src must have been approved.
//
// When the token request names no scopes, src is returned unchanged. Otherwise the requested scopes
// must all be valid for client and must be a subset of src's scopes. The derived request then takes
// those scopes and the token request's grant type. It also receives the token request's parameters
// and extensions, minus the keys in refreshIgnoreParams.
func reduceScope(c context.Context, client oauth2.OAuth2Client, src oauth2.OAuth2Request, request *auth.TokenRequest) (oauth2.OAuth2Request, error) {
	switch {
	case !src.Approved():
		return nil, oauth2.NewInvalidGrantError("original OAuth2 request was not approved")
	case len(request.Scopes) == 0:
		// no scope reduction requested
		return src, nil
	}

	if err := auth.ValidateAllScopes(c, client, request.Scopes); err != nil {
		return nil, err
	}
	if subset, invalid := auth.IsSubSet(c, src.Scopes(), request.Scopes); !subset {
		return nil, oauth2.NewInvalidScopeError(fmt.Sprintf("scope [%s] was not originally authorized", invalid))
	}

	return src.NewOAuth2Request(func(opt *oauth2.RequestDetails) {
		opt.GrantType = request.GrantType
		opt.Scopes = request.Scopes
		for key, value := range request.Parameters {
			if !refreshIgnoreParams.Has(key) {
				opt.Parameters[key] = value
			}
		}
		for key, value := range request.Extensions {
			if !refreshIgnoreParams.Has(key) {
				opt.Extensions[key] = value
			}
		}
	}), nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package grants
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"strings"
)
var (
// switchTenantPermissions lists the permissions the authenticated access token must hold for
// SwitchTenantGranter.validate to allow a tenant switch.
switchTenantPermissions = []string{
security.SpecialPermissionSwitchTenant,
}
)
// SwitchTenantGranter implements auth.TokenGranter for the switch-tenant grant. It embeds
// PermissionBasedGranter for access-token authentication and the shared permission/client/scope checks.
type SwitchTenantGranter struct {
	PermissionBasedGranter
	accountStore security.AccountStore
	authService  auth.AuthorizationService
}

// NewSwitchTenantGranter creates a SwitchTenantGranter. It panics when authenticator, authService or
// accountStore is nil, checked in that order.
func NewSwitchTenantGranter(
	authService auth.AuthorizationService,
	authenticator security.Authenticator,
	accountStore security.AccountStore,
) *SwitchTenantGranter {
	if authenticator == nil {
		panic(fmt.Errorf("cannot create SwitchTenantGranter without authenticator."))
	}
	if authService == nil {
		panic(fmt.Errorf("cannot create SwitchTenantGranter without authorization service."))
	}
	if accountStore == nil {
		panic(fmt.Errorf("cannot create SwitchTenantGranter without account store."))
	}
	return &SwitchTenantGranter{
		PermissionBasedGranter: PermissionBasedGranter{
			authenticator: authenticator,
		},
		authService:  authService,
		accountStore: accountStore,
	}
}
// Grant handles the switch-tenant grant. For any other grant type it returns (nil, nil).
//
// Starting from the security context of the request's access token, it validates the request, the
// token's permissions and the requesting client. It then derives a reduced-scope OAuth2 request and
// issues a new access token from AuthorizationService.SwitchAuthentication, which presumably applies
// the requested tenant.
func (g *SwitchTenantGranter) Grant(ctx context.Context, request *auth.TokenRequest) (oauth2.AccessToken, error) {
	if request.GrantType != oauth2.GrantTypeSwitchTenant {
		return nil, nil
	}

	client := auth.RetrieveAuthenticatedClient(ctx)
	if err := auth.ValidateGrant(ctx, client, request.GrantType); err != nil {
		return nil, err
	}
	if err := g.validateRequest(ctx, request); err != nil {
		return nil, err
	}

	// the security context we are switching from
	current, err := g.authenticateToken(ctx, request)
	if err != nil {
		return nil, err
	}
	if err = g.validate(ctx, request, current); err != nil {
		return nil, err
	}
	if err = g.validateStoredClient(ctx, client, current.OAuth2Request()); err != nil {
		return nil, err
	}

	switched, err := g.reduceScope(ctx, client, current.OAuth2Request(), request)
	if err != nil {
		return nil, err
	}
	oauth, err := g.authService.SwitchAuthentication(ctx, switched, current.UserAuthentication(), current)
	if err != nil {
		return nil, oauth2.NewInvalidGrantError(err)
	}
	token, err := g.authService.CreateAccessToken(ctx, oauth)
	if err != nil {
		return nil, oauth2.NewInvalidGrantError(err)
	}
	return token, nil
}
// validateRequest requires the tenant ID or tenant external ID extension to be present as a string,
// and at least one of them to be non-blank. ctx is unused.
func (g *SwitchTenantGranter) validateRequest(ctx context.Context, request *auth.TokenRequest) error {
	tenantId, hasId := request.Extensions[oauth2.ParameterTenantId].(string)
	externalId, hasExternalId := request.Extensions[oauth2.ParameterTenantExternalId].(string)
	switch {
	case !hasId && !hasExternalId:
		return oauth2.NewInvalidTokenRequestError(fmt.Sprintf("both [%s] and [%s] are missing", oauth2.ParameterTenantId, oauth2.ParameterTenantExternalId))
	case strings.TrimSpace(tenantId) == "" && strings.TrimSpace(externalId) == "":
		return oauth2.NewInvalidTokenRequestError(fmt.Sprintf("both [%s] and [%s] are empty", oauth2.ParameterTenantId, oauth2.ParameterTenantExternalId))
	}
	return nil
}
func (g *SwitchTenantGranter) validate(ctx context.Context, request *auth.TokenRequest, stored security.Authentication) error {
if e := g.PermissionBasedGranter.validateStoredPermissions(ctx, stored, switchTenantPermissions...); e != nil {
return e
}
if proxy, ok := stored.Details().(security.ProxiedUserDetails); ok && proxy.Proxied() {
return oauth2.NewInvalidGrantError("the access token represents a masqueraded context. need original token to switch tenant")
}
srcTenant, ok := stored.Details().(security.TenantDetails)
if !ok {
// there wasn't any tenant. shouldn't happen, but we allow it because it won't cause any trouble
return nil
}
var tenantHasChanged bool
tenantId, _ := request.Extensions[oauth2.ParameterTenantId].(string)
if strings.TrimSpace(tenantId) != srcTenant.TenantId() {
tenantHasChanged = true
}
tenantExternalId, _ := request.Extensions[oauth2.ParameterTenantExternalId].(string)
if strings.TrimSpace(tenantExternalId) == srcTenant.TenantExternalId() {
tenantHasChanged = true
}
if !tenantHasChanged {
return oauth2.NewInvalidGrantError("cannot switch to same tenant")
}
return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package grants
import (
"context"
"fmt"
securityinternal "github.com/cisco-open/go-lanai/internal/security"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/tenancy"
"github.com/cisco-open/go-lanai/pkg/utils"
"strings"
)
var (
// switchUserPermissions lists the permissions the authenticated access token must hold for
// SwitchUserGranter.validate to allow a user switch.
switchUserPermissions = []string{security.SpecialPermissionSwitchUser}
)
// SwitchUserGranter implements auth.TokenGranter for the switch-user grant. It embeds
// PermissionBasedGranter and uses accountStore to load the account being switched to.
type SwitchUserGranter struct {
	PermissionBasedGranter
	authService  auth.AuthorizationService
	accountStore security.AccountStore
}

// NewSwitchUserGranter creates a SwitchUserGranter. It panics when authenticator, authService or
// accountStore is nil, checked in that order.
func NewSwitchUserGranter(authService auth.AuthorizationService, authenticator security.Authenticator, accountStore security.AccountStore) *SwitchUserGranter {
	switch {
	case authenticator == nil:
		panic(fmt.Errorf("cannot create SwitchUserGranter without authenticator."))
	case authService == nil:
		panic(fmt.Errorf("cannot create SwitchUserGranter without authorization service."))
	case accountStore == nil:
		panic(fmt.Errorf("cannot create SwitchUserGranter without account store."))
	}
	return &SwitchUserGranter{
		PermissionBasedGranter: PermissionBasedGranter{authenticator: authenticator},
		authService:            authService,
		accountStore:           accountStore,
	}
}
// Grant handles the switch-user grant. For any other grant type it returns (nil, nil).
//
// Starting from the security context of the request's access token, it validates the request, the
// token's permissions and the requesting client, and derives a reduced-scope OAuth2 request. It then
// loads the target user and issues a new access token from AuthorizationService.SwitchAuthentication.
func (g *SwitchUserGranter) Grant(ctx context.Context, request *auth.TokenRequest) (oauth2.AccessToken, error) {
	if request.GrantType != oauth2.GrantTypeSwitchUser {
		return nil, nil
	}

	client := auth.RetrieveAuthenticatedClient(ctx)
	if err := auth.ValidateGrant(ctx, client, request.GrantType); err != nil {
		return nil, err
	}
	if err := g.validateRequest(ctx, request); err != nil {
		return nil, err
	}

	// the security context we are switching from
	current, err := g.authenticateToken(ctx, request)
	if err != nil {
		return nil, err
	}
	if err = g.validate(ctx, request, current); err != nil {
		return nil, err
	}
	if err = g.validateStoredClient(ctx, client, current.OAuth2Request()); err != nil {
		return nil, err
	}

	switched, err := g.reduceScope(ctx, client, current.OAuth2Request(), request)
	if err != nil {
		return nil, err
	}
	target, err := g.loadUserAuthentication(ctx, request)
	if err != nil {
		return nil, err
	}

	oauth, err := g.authService.SwitchAuthentication(ctx, switched, target, current)
	if err != nil {
		return nil, oauth2.NewInvalidGrantError(err)
	}
	token, err := g.authService.CreateAccessToken(ctx, oauth)
	if err != nil {
		return nil, oauth2.NewInvalidGrantError(err)
	}
	return token, nil
}
// validateRequest requires the switch username or switch user ID extension to be present as a
// string, and at least one of them to be non-blank. When both are supplied, loadUserAuthentication
// prefers the username. ctx is unused.
func (g *SwitchUserGranter) validateRequest(ctx context.Context, request *auth.TokenRequest) error {
	username, hasUsername := request.Extensions[oauth2.ParameterSwitchUsername].(string)
	userId, hasUserId := request.Extensions[oauth2.ParameterSwitchUserId].(string)
	if !(hasUsername || hasUserId) {
		return oauth2.NewInvalidTokenRequestError(fmt.Sprintf("both [%s] and [%s] are missing", oauth2.ParameterSwitchUsername, oauth2.ParameterSwitchUserId))
	}
	if strings.TrimSpace(username) != "" || strings.TrimSpace(userId) != "" {
		return nil
	}
	return oauth2.NewInvalidTokenRequestError(fmt.Sprintf("both [%s] and [%s] are empty", oauth2.ParameterSwitchUsername, oauth2.ParameterSwitchUserId))
}
// validate checks whether the security context in stored may switch user. It requires that:
//   - stored holds every permission in switchUserPermissions;
//   - stored's details expose tenant access that covers all tenants (see canAccessAllTenants);
//   - stored's details describe a user;
//   - stored is not already a masqueraded (proxied) context;
//   - neither the requested username nor the requested user ID (trimmed) equals the current user's.
func (g *SwitchUserGranter) validate(ctx context.Context, request *auth.TokenRequest, stored security.Authentication) error {
	if err := g.PermissionBasedGranter.validateStoredPermissions(ctx, stored, switchUserPermissions...); err != nil {
		return err
	}

	details := stored.Details()
	tenantAccess, ok := details.(securityinternal.TenantAccessDetails)
	if !ok {
		return oauth2.NewInvalidGrantError("access token is not associated with a list of tenants")
	}
	if !canAccessAllTenants(ctx, tenantAccess.EffectiveAssignedTenantIds()) {
		return oauth2.NewInvalidGrantError("user needs to be able to access all tenants to switch user")
	}
	currentUser, ok := details.(security.UserDetails)
	if !ok {
		return oauth2.NewInvalidGrantError("access token is not associated with a valid user")
	}
	if proxied, isProxy := details.(security.ProxiedUserDetails); isProxy && proxied.Proxied() {
		return oauth2.NewInvalidGrantError("the access token represents a masqueraded context. Nested masquerading is not supported")
	}

	targetUsername, _ := request.Extensions[oauth2.ParameterSwitchUsername].(string)
	targetUserId, _ := request.Extensions[oauth2.ParameterSwitchUserId].(string)
	if strings.TrimSpace(targetUsername) == currentUser.Username() || strings.TrimSpace(targetUserId) == currentUser.UserId() {
		return oauth2.NewInvalidGrantError("cannot switch to same user")
	}
	return nil
}
// loadUserAuthentication loads the account being switched to and wraps it in an authenticated user
// authentication. The principal is the account's username and the permissions are the account's
// permissions.
//
// The account is looked up by the trimmed switch username when it is non-empty, otherwise by the
// trimmed switch user ID. A lookup error, or a lookup returning no account, yields an invalid-grant
// error naming the parameter used.
func (g *SwitchUserGranter) loadUserAuthentication(ctx context.Context, request *auth.TokenRequest) (security.Authentication, error) {
	username, _ := request.Extensions[oauth2.ParameterSwitchUsername].(string)
	userId, _ := request.Extensions[oauth2.ParameterSwitchUserId].(string)
	username = strings.TrimSpace(username)
	userId = strings.TrimSpace(userId)

	var account security.Account
	var e error
	// we prefer username over userId; param/value identify the lookup in error messages
	param, value := oauth2.ParameterSwitchUsername, username
	if username != "" {
		account, e = g.accountStore.LoadAccountByUsername(ctx, username)
	} else {
		param, value = oauth2.ParameterSwitchUserId, userId
		account, e = g.accountStore.LoadAccountById(ctx, userId)
	}
	switch {
	case e != nil:
		return nil, oauth2.NewInvalidGrantError(fmt.Sprintf("invalid %s [%s]", param, value), e)
	case account == nil:
		// a store returning (nil, nil) would otherwise cause a nil-interface panic below
		return nil, oauth2.NewInvalidGrantError(fmt.Sprintf("invalid %s [%s]", param, value))
	}

	accountPerms := account.Permissions()
	permissions := make(map[string]interface{}, len(accountPerms))
	for _, v := range accountPerms {
		permissions[v] = true
	}
	return oauth2.NewUserAuthentication(func(opt *oauth2.UserAuthOption) {
		opt.Principal = account.Username()
		opt.Permissions = permissions
		opt.State = security.StateAuthenticated
	}), nil
}
// canAccessAllTenants reports whether assignedTenantIds grants access to every tenant. That holds
// when it contains security.SpecialTenantIdWildcard, or when it contains the root tenant ID returned
// by tenancy.GetRoot. Failing to resolve the root tenant (an error or an empty ID) means no access.
func canAccessAllTenants(ctx context.Context, assignedTenantIds utils.StringSet) bool {
	if assignedTenantIds.Has(security.SpecialTenantIdWildcard) {
		return true
	}
	rootId, err := tenancy.GetRoot(ctx)
	return err == nil && rootId != "" && assignedTenantIds.Has(rootId)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package misc
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/claims"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/tokenauth"
"github.com/cisco-open/go-lanai/pkg/utils"
)
const (
msgInvalidTokenType = "unsupported token type"
msgInvalidToken = "token is invalid or expired"
hintAccessToken = "access_token"
hintRefreshToken = "refresh_token"
)
type CheckTokenRequest struct {
Token string `form:"token"`
Hint string `form:"token_type_hint"`
NoDetails bool `form:"no_details"`
}
type CheckTokenEndpoint struct {
issuer security.Issuer
authenticator security.Authenticator
tokenStoreReader oauth2.TokenStoreReader
}
func NewCheckTokenEndpoint(issuer security.Issuer, tokenStoreReader oauth2.TokenStoreReader) *CheckTokenEndpoint {
authenticator := tokenauth.NewAuthenticator(func(opt *tokenauth.AuthenticatorOption) {
opt.TokenStoreReader = tokenStoreReader
})
return &CheckTokenEndpoint{
issuer: issuer,
authenticator: authenticator,
tokenStoreReader: tokenStoreReader,
}
}
// CheckToken is the token introspection end point as defined in https://datatracker.ietf.org/doc/html/rfc7662
// This endpoint is used by protected resources to query the authorization server to determine the state and metadata of a token.
// Because this request is issued by a protected resource, the client used by the protected resource is not going to be the same
// as the client the token is issued for.
// The auth server require the protected resource to be specifically authorized to call this endpoint by means of client authentication
// and client scope (token_details).
// This end point is not meant to be used for other means. Any client that's not a protected resource should not be given this scope.
func (ep *CheckTokenEndpoint) CheckToken(c context.Context, request *CheckTokenRequest) (response *CheckTokenClaims, err error) {
client := auth.RetrieveAuthenticatedClient(c)
if client == nil {
return nil, oauth2.NewInvalidClientError("check token endpoint requires client authentication")
}
switch request.Hint {
case "":
fallthrough
case hintAccessToken:
if request.NoDetails || !ep.allowDetails(c, client) {
return ep.checkAccessTokenWithoutDetails(c, request)
}
return ep.checkAccessTokenWithDetails(c, request)
case hintRefreshToken:
return ep.checkRefreshToken(c, request)
default:
return nil, oauth2.NewUnsupportedTokenTypeError(fmt.Sprintf("token_type_hint '%s' is not supported", request.Hint))
}
}
func (ep *CheckTokenEndpoint) allowDetails(_ context.Context, client oauth2.OAuth2Client) bool {
return client.Scopes() != nil && client.Scopes().Has(oauth2.ScopeTokenDetails)
}
func (ep *CheckTokenEndpoint) checkAccessTokenWithoutDetails(c context.Context, request *CheckTokenRequest) (response *CheckTokenClaims, err error) {
token, e := ep.tokenStoreReader.ReadAccessToken(c, request.Token)
if e != nil || token == nil || token.Expired() {
//nolint:nilerr // we hide error in response and returns compliant response
return ep.inactiveTokenResponse(), nil
}
return ep.activeTokenResponseWithoutDetails(), nil
}
func (ep *CheckTokenEndpoint) checkAccessTokenWithDetails(c context.Context, request *CheckTokenRequest) (response *CheckTokenClaims, err error) {
candidate := tokenauth.BearerToken{
Token: request.Token,
DetailsMap: map[string]interface{}{},
}
oauth, e := ep.authenticator.Authenticate(c, &candidate)
if e != nil || oauth.State() < security.StateAuthenticated {
//nolint:nilerr // we hide error in response and returns compliant response
return ep.inactiveTokenResponse(), nil
}
return ep.activeTokenResponseWithDetails(c, oauth.(oauth2.Authentication)), nil
}
func (ep *CheckTokenEndpoint) checkRefreshToken(_ context.Context, request *CheckTokenRequest) (response *CheckTokenClaims, err error) {
// We don't support refresh token check for now
return nil, oauth2.NewUnsupportedTokenTypeError(fmt.Sprintf("token_type_hint '%s' is not supported", request.Hint))
}
func (ep *CheckTokenEndpoint) inactiveTokenResponse() *CheckTokenClaims {
return &CheckTokenClaims{
Active: &utils.FALSE,
}
}
func (ep *CheckTokenEndpoint) activeTokenResponseWithoutDetails() *CheckTokenClaims {
return &CheckTokenClaims{
Active: &utils.TRUE,
}
}
func (ep *CheckTokenEndpoint) activeTokenResponseWithDetails(ctx context.Context, auth oauth2.Authentication) *CheckTokenClaims {
c := CheckTokenClaims{
Active: &utils.TRUE,
}
e := claims.Populate(ctx, &c,
claims.WithSpecs(claims.CheckTokenClaimSpecs),
claims.WithSource(auth),
claims.WithIssuer(ep.issuer))
if e != nil {
return ep.activeTokenResponseWithoutDetails()
}
return &c
}
// Old impl. without claims factory, for reference only
//func (ep *CheckTokenEndpoint) activeTokenResponseWithDetails(auth oauth2.Authentication) *CheckTokenClaims {
// claims := CheckTokenClaims{
// Active: &utils.TRUE,
// BasicClaims: oauth2.BasicClaims{
// Audience: auth.OAuth2Request().ClientId(),
// ExpiresAt: auth.Details().(security.ContextDetails).ExpiryTime(),
// //Id: auth.AccessToken().Id,
// IssuedAt: auth.Details().(security.ContextDetails).IssueTime(),
// //Issuer: auth.AccessToken(),
// //NotBefore: auth.AccessToken(),
// Subject: auth.UserAuthentication().Principal().(string),
// Scopes: auth.OAuth2Request().Scopes(),
// ClientId: auth.OAuth2Request().ClientId(),
// },
// Username: auth.UserAuthentication().Principal().(string),
// AuthTime: auth.Details().(security.ContextDetails).AuthenticationTime(),
// FirstName: auth.Details().(security.UserDetails).FirstName(),
// LastName: auth.Details().(security.UserDetails).LastName(),
// Email: auth.Details().(security.UserDetails).Email(),
// Locale: auth.Details().(security.UserDetails).LocaleCode(),
//
// UserId: auth.Details().(security.UserDetails).UserId(),
// AccountType: auth.Details().(security.UserDetails).AccountType().String(),
// Currency: auth.Details().(security.UserDetails).CurrencyCode(),
// AssignedTenants: auth.Details().(security.UserDetails).AssignedTenantIds(),
// TenantId: auth.Details().(security.TenantDetails).TenantId(),
// TenantExternalId: auth.Details().(security.TenantDetails).TenantExternalId(),
// TenantSuspended: utils.BoolPtr(auth.Details().(security.TenantDetails).TenantSuspended()),
// ProviderId: auth.Details().(security.ProviderDetails).ProviderId(),
// ProviderName: auth.Details().(security.ProviderDetails).ProviderName(),
// Roles: auth.Details().(security.ContextDetails).Roles(),
// Permissions: auth.Details().(security.ContextDetails).Permissions(),
// OrigUsername: auth.Details().(security.ProxiedUserDetails).OriginalUsername(),
// }
//
// return &claims
//}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package misc
import (
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
// CheckTokenClaims implemnts oauth2.Claims
type CheckTokenClaims struct {
oauth2.FieldClaimsMapper
/*******************************
* Standard Check Token claims
*******************************/
oauth2.BasicClaims
Active *bool `claim:"active"`
Username string `claim:"username"`
/*******************************
* Standard OIDC claims
*******************************/
FirstName string `claim:"given_name"`
LastName string `claim:"family_name"`
Email string `claim:"email"`
Locale string `claim:"locale"` // Typically ISO 639-1 Alpha-2 [ISO639‑1] language code in lowercase and an ISO 3166-1
AuthTime time.Time `claim:"auth_time"`
/*******************************
* NFV Additional Claims
*******************************/
UserId string `claim:"user_id"`
AccountType string `claim:"account_type"`
Currency string `claim:"currency"`
TenantId string `claim:"tenant_id"`
TenantExternalId string `claim:"tenant_name"` //This maps to Tenant's ExternalId for backward compatibility
TenantSuspended *bool `claim:"tenant_suspended"`
ProviderId string `claim:"provider_id"`
ProviderName string `claim:"provider_name"`
ProviderDisplayName string `claim:"provider_display_name"`
ProviderDescription string `claim:"provider_description"`
ProviderNotificationType string `claim:"provider_notification_type"`
ProviderEmail string `claim:"provider_email"`
AssignedTenants utils.StringSet `claim:"assigned_tenants"`
Roles utils.StringSet `claim:"roles"`
Permissions utils.StringSet `claim:"permissions"`
OrigUsername string `claim:"original_username"`
}
func (c *CheckTokenClaims) MarshalJSON() ([]byte, error) {
return c.FieldClaimsMapper.DoMarshalJSON(c)
}
func (c *CheckTokenClaims) UnmarshalJSON(bytes []byte) error {
return c.FieldClaimsMapper.DoUnmarshalJSON(c, bytes)
}
func (c *CheckTokenClaims) Get(claim string) interface{} {
return c.FieldClaimsMapper.Get(c, claim)
}
func (c *CheckTokenClaims) Has(claim string) bool {
return c.FieldClaimsMapper.Has(c, claim)
}
func (c *CheckTokenClaims) Set(claim string, value interface{}) {
c.FieldClaimsMapper.Set(c, claim, value)
}
func (c *CheckTokenClaims) Values() map[string]interface{} {
return c.FieldClaimsMapper.Values(c)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package misc
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
"github.com/cisco-open/go-lanai/pkg/web"
"net/http"
)
const (
JwkTypeRSA = "RSA"
)
type JwkSetRequest struct {
Kid string `uri:"kid"`
}
type JwkSetResponse struct {
Keys []jwt.Jwk `json:"keys"`
}
type JwkSetEndpoint struct {
jwkStore jwt.JwkStore
}
func NewJwkSetEndpoint(jwkStore jwt.JwkStore) *JwkSetEndpoint {
return &JwkSetEndpoint{
jwkStore: jwkStore,
}
}
func (ep *JwkSetEndpoint) JwkByKid(ctx context.Context, req *JwkSetRequest) (resp jwt.Jwk, err error) {
jwk, e := ep.jwkStore.LoadByKid(ctx, req.Kid)
if e != nil {
return nil, web.NewHttpError(http.StatusNotFound, e)
}
return jwk, nil
}
func (ep *JwkSetEndpoint) JwkSet(ctx context.Context, _ *JwkSetRequest) (resp *JwkSetResponse, err error) {
jwks, e := ep.jwkStore.LoadAll(ctx)
if e != nil {
return nil, oauth2.NewGenericError(e.Error())
}
resp = &JwkSetResponse{
Keys: jwks,
}
return resp, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package misc
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/tenancy"
"github.com/cisco-open/go-lanai/pkg/web"
)
type TenantHierarchyEndpoint struct {
}
func NewTenantHierarchyEndpoint() *TenantHierarchyEndpoint {
return &TenantHierarchyEndpoint{}
}
type HierarchyRequest struct {
TenantId string `form:"tenantId"`
}
func (endpoint *TenantHierarchyEndpoint) GetParent(ctx context.Context, req *HierarchyRequest) (string, error) {
if allow, err := allowAccess(ctx); !allow {
return "", err
}
p, err := tenancy.GetParent(ctx, req.TenantId)
if err != nil {
return "", err
} else {
return p, err
}
}
func (endpoint *TenantHierarchyEndpoint) GetChildren(ctx context.Context, req *HierarchyRequest) (interface{}, error) {
if allow, err := allowAccess(ctx); !allow {
return nil, err
}
children, err := tenancy.GetChildren(ctx, req.TenantId)
if err == nil {
ret := children
return ret, nil
} else {
return nil, err
}
}
func (endpoint *TenantHierarchyEndpoint) GetAncestors(ctx context.Context, req *HierarchyRequest) (interface{}, error) {
if allow, err := allowAccess(ctx); !allow {
return nil, err
}
ancestor, err := tenancy.GetAncestors(ctx, req.TenantId)
if err == nil {
return ancestor, nil
} else {
return nil, err
}
}
func (endpoint *TenantHierarchyEndpoint) GetDescendants(ctx context.Context, req *HierarchyRequest) (interface{}, error) {
if allow, err := allowAccess(ctx); !allow {
return nil, err
}
descendants, err := tenancy.GetDescendants(ctx, req.TenantId)
if err == nil {
ret := descendants
return ret, nil
} else {
return nil, err
}
}
func (endpoint *TenantHierarchyEndpoint) GetRoot(ctx context.Context, _ *web.EmptyRequest) (string, error) {
if allow, err := allowAccess(ctx); !allow {
return "", err
}
root, err := tenancy.GetRoot(ctx)
if err != nil {
return "", err
} else {
return root, nil
}
}
func allowAccess(ctx context.Context) (bool, error) {
client := auth.RetrieveAuthenticatedClient(ctx)
if client == nil {
return false, oauth2.NewInvalidClientError("tenant hierarchy endpoint requires client authentication")
}
if !client.Scopes().Has(oauth2.ScopeTenantHierarchy) {
return false, oauth2.NewInsufficientScopeError("tenant hierarchy endpoint requires tenant_hierarchy scope")
}
return true, nil
}
func StringResponseEncoder() web.EncodeResponseFunc {
return web.CustomResponseEncoder(func(opt *web.EncodeOption) {
opt.ContentType = "application/json; charset=utf-8"
opt.WriteFunc = web.TextWriteFunc
})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package misc
import (
"context"
"encoding/json"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/claims"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/openid"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
"github.com/cisco-open/go-lanai/pkg/web"
)
var (
scopedSpecs = map[string]map[string]claims.ClaimSpec{
oauth2.ScopeOidcProfile: claims.ProfileScopeSpecs,
oauth2.ScopeOidcEmail: claims.EmailScopeSpecs,
oauth2.ScopeOidcPhone: claims.PhoneScopeSpecs,
oauth2.ScopeOidcAddress: claims.AddressScopeSpecs,
}
defaultSpecs = []map[string]claims.ClaimSpec{
claims.UserInfoBasicSpecs,
}
fullSpecs = []map[string]claims.ClaimSpec{
claims.UserInfoBasicSpecs,
claims.ProfileScopeSpecs,
claims.EmailScopeSpecs,
claims.PhoneScopeSpecs,
claims.AddressScopeSpecs,
}
)
type UserInfoRequest struct{}
type UserInfoPlainResponse struct {
UserInfoClaims
}
type UserInfoJwtResponse string
// MarshalText implements encoding.TextMarshaler
func (r UserInfoJwtResponse) MarshalText() (text []byte, err error) {
return []byte(r), nil
}
type UserInfoEndpoint struct {
issuer security.Issuer
accountStore security.AccountStore
jwtEncoder jwt.JwtEncoder
}
func NewUserInfoEndpoint(issuer security.Issuer, accountStore security.AccountStore, jwtEncoder jwt.JwtEncoder) *UserInfoEndpoint {
return &UserInfoEndpoint{
issuer: issuer,
accountStore: accountStore,
jwtEncoder: jwtEncoder,
}
}
func (ep *UserInfoEndpoint) PlainUserInfo(ctx context.Context, _ UserInfoRequest) (resp *UserInfoPlainResponse, err error) {
auth, ok := security.Get(ctx).(oauth2.Authentication)
if !ok || auth.UserAuthentication() == nil {
return nil, oauth2.NewAccessRejectedError("missing user authentication")
}
specs := ep.determineClaimSpecs(auth.OAuth2Request())
requested := ep.determineRequestedClaims(auth.OAuth2Request())
c := UserInfoClaims{}
e := claims.Populate(ctx, &c,
claims.WithSpecs(specs...),
claims.WithSource(auth),
claims.WithIssuer(ep.issuer),
claims.WithAccountStore(ep.accountStore),
claims.WithRequestedClaims(requested, fullSpecs...),
)
if e != nil {
return nil, oauth2.NewInternalError(e)
}
return &UserInfoPlainResponse{
UserInfoClaims: c,
}, nil
}
func (ep *UserInfoEndpoint) JwtUserInfo(ctx context.Context, _ UserInfoRequest) (resp UserInfoJwtResponse, err error) {
auth, ok := security.Get(ctx).(oauth2.Authentication)
if !ok || auth.UserAuthentication() == nil {
return "", oauth2.NewAccessRejectedError("missing user authentication")
}
c := UserInfoClaims{}
e := claims.Populate(ctx, &c,
claims.WithSpecs(
claims.UserInfoBasicSpecs,
claims.ProfileScopeSpecs,
claims.EmailScopeSpecs,
claims.PhoneScopeSpecs,
claims.AddressScopeSpecs,
),
claims.WithSource(auth),
claims.WithIssuer(ep.issuer),
claims.WithAccountStore(ep.accountStore),
)
if e != nil {
return "", oauth2.NewInternalError(err)
}
token, e := ep.jwtEncoder.Encode(ctx, &c)
if e != nil {
return "", oauth2.NewInternalError(e)
}
return UserInfoJwtResponse(token), nil
}
// determineClaimSpecs works slightly different from the id_token version:
// When openid scope is not in the request, full specs is given
func (ep *UserInfoEndpoint) determineClaimSpecs(request oauth2.OAuth2Request) []map[string]claims.ClaimSpec {
if request == nil || request.Scopes() == nil || !request.Approved() {
return defaultSpecs
}
if !request.Scopes().Has(oauth2.ScopeOidc) {
return fullSpecs
}
specs := make([]map[string]claims.ClaimSpec, len(defaultSpecs), len(defaultSpecs)+len(request.Scopes()))
for i, spec := range defaultSpecs {
specs[i] = spec
}
scopes := request.Scopes()
for scope, spec := range scopedSpecs {
if scopes.Has(scope) {
specs = append(specs, spec)
}
}
return specs
}
func (ep *UserInfoEndpoint) determineRequestedClaims(request oauth2.OAuth2Request) claims.RequestedClaims {
raw, ok := request.Extensions()[oauth2.ParameterClaims].(string)
if !ok {
return nil
}
cr := openid.ClaimsRequest{}
if e := json.Unmarshal([]byte(raw), &cr); e != nil {
return nil
}
return cr.UserInfo
}
func JwtResponseEncoder() web.EncodeResponseFunc {
return web.CustomResponseEncoder(func(opt *web.EncodeOption) {
opt.ContentType = "application/jwt; charset=utf-8"
opt.WriteFunc = web.TextWriteFunc
})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package misc
import (
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/claims"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
type UserInfoClaims struct {
oauth2.FieldClaimsMapper
/*******************************
* Standard JWT claims
*******************************/
Issuer string `claim:"iss"`
Audience oauth2.StringSetClaim `claim:"aud"`
Subject string `claim:"sub"`
/*******************************
* Standard OIDC claims
*******************************/
FullName string `claim:"name"`
FirstName string `claim:"given_name"`
LastName string `claim:"family_name"`
MiddleName string `claim:"middle_name"`
Nickname string `claim:"nickname"`
PreferredUsername string `claim:"preferred_username"`
ProfileUrl string `claim:"profile"`
PictureUrl string `claim:"picture"`
Website string `claim:"website"`
Email string `claim:"email"`
EmailVerified *bool `claim:"email_verified"`
Gender string `claim:"gender"`
Birthday string `claim:"birthdate"` // ISO 8601:2004 [ISO8601‑2004] YYYY-MM-DD format
ZoneInfo string `claim:"zoneinfo"` // Europe/Paris or America/Los_Angeles
Locale string `claim:"locale"` // Typically ISO 639-1 Alpha-2 [ISO639‑1] language code in lowercase and an ISO 3166-1
PhoneNumber string `claim:"phone_number"` // RFC 3966 [RFC3966] e.g. +1 (604) 555-1234;ext=5678
PhoneNumVerified *bool `claim:"phone_number_verified"`
Address *claims.AddressClaim `claim:"address"`
UpdatedAt time.Time `claim:"updated_at"`
/*******************************
* NFV Additional Claims
*******************************/
AccountType string `claim:"account_type"`
DefaultTenantId string `claim:"default_tenant_id"`
AssignedTenants utils.StringSet `claim:"assigned_tenants"`
Roles utils.StringSet `claim:"roles"`
Permissions utils.StringSet `claim:"permissions"`
}
func (c UserInfoClaims) MarshalJSON() ([]byte, error) {
return c.FieldClaimsMapper.DoMarshalJSON(c)
}
func (c *UserInfoClaims) UnmarshalJSON(bytes []byte) error {
return c.FieldClaimsMapper.DoUnmarshalJSON(c, bytes)
}
func (c UserInfoClaims) Get(claim string) interface{} {
return c.FieldClaimsMapper.Get(c, claim)
}
func (c UserInfoClaims) Has(claim string) bool {
return c.FieldClaimsMapper.Has(c, claim)
}
func (c *UserInfoClaims) Set(claim string, value interface{}) {
c.FieldClaimsMapper.Set(c, claim, value)
}
func (c UserInfoClaims) Values() map[string]interface{} {
return c.FieldClaimsMapper.Values(c)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package misc
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/idp"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/claims"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/openid"
"net/http"
)
// WellKnownEndpoint provide "/.well-known/**" HTTP endpoints
type WellKnownEndpoint struct {
issuer security.Issuer
extra map[string]interface{}
}
func NewWellKnownEndpoint(issuer security.Issuer, idpManager idp.IdentityProviderManager, extra map[string]interface{}) *WellKnownEndpoint {
if extra == nil {
extra = map[string]interface{}{}
}
extra[openid.OPMetaExtraSourceIDPManager] = idpManager
return &WellKnownEndpoint{
issuer: issuer,
extra: extra,
}
}
// OpenIDConfig should mapped to GET /.well-known/openid-configuration
func (ep *WellKnownEndpoint) OpenIDConfig(ctx context.Context, _ *http.Request) (resp *openid.OPMetadata, err error) {
c := openid.OPMetadata{MapClaims: oauth2.MapClaims{}}
e := claims.Populate(ctx, &c,
claims.WithSpecs(openid.OPMetadataBasicSpecs, openid.OPMetadataOptionalSpecs),
claims.WithIssuer(ep.issuer),
claims.WithExtraSource(ep.extra),
)
if e != nil {
return nil, e
}
return &c, nil
}
// Code generated by mockery v2.20.0. DO NOT EDIT.
package mocks
import (
context "context"
auth "github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
mock "github.com/stretchr/testify/mock"
)
// AccessRevoker is an autogenerated mock type for the AccessRevoker type
type AccessRevoker struct {
mock.Mock
}
// RevokeWithClientId provides a mock function with given fields: ctx, clientId, revokeRefreshToken
func (_m *AccessRevoker) RevokeWithClientId(ctx context.Context, clientId string, revokeRefreshToken bool) error {
ret := _m.Called(ctx, clientId, revokeRefreshToken)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, bool) error); ok {
r0 = rf(ctx, clientId, revokeRefreshToken)
} else {
r0 = ret.Error(0)
}
return r0
}
// RevokeWithSessionId provides a mock function with given fields: ctx, sessionId, sessionName
func (_m *AccessRevoker) RevokeWithSessionId(ctx context.Context, sessionId string, sessionName string) error {
ret := _m.Called(ctx, sessionId, sessionName)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, sessionId, sessionName)
} else {
r0 = ret.Error(0)
}
return r0
}
// RevokeWithTokenValue provides a mock function with given fields: ctx, tokenValue, hint
func (_m *AccessRevoker) RevokeWithTokenValue(ctx context.Context, tokenValue string, hint auth.RevokerTokenHint) error {
ret := _m.Called(ctx, tokenValue, hint)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, auth.RevokerTokenHint) error); ok {
r0 = rf(ctx, tokenValue, hint)
} else {
r0 = ret.Error(0)
}
return r0
}
// RevokeWithUsername provides a mock function with given fields: ctx, username, revokeRefreshToken
func (_m *AccessRevoker) RevokeWithUsername(ctx context.Context, username string, revokeRefreshToken bool) error {
ret := _m.Called(ctx, username, revokeRefreshToken)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, bool) error); ok {
r0 = rf(ctx, username, revokeRefreshToken)
} else {
r0 = ret.Error(0)
}
return r0
}
type mockConstructorTestingTNewAccessRevoker interface {
mock.TestingT
Cleanup(func())
}
// NewAccessRevoker creates a new instance of AccessRevoker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewAccessRevoker(t mockConstructorTestingTNewAccessRevoker) *AccessRevoker {
mock := &AccessRevoker{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package openid
import (
"context"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/web"
"io"
"net/http"
"net/url"
"strings"
"time"
)
const (
keyPromptProcessed = "X-OIDC-PROMPT-PROCESSED"
)
var (
supportedResponseTypes = utils.NewStringSet("id_token", "token", "code")
)
// OpenIDAuthorizeRequestProcessor implements ChainedAuthorizeRequestProcessor and order.Ordered
// it validate auth request against standard oauth2 specs
//goland:noinspection GoNameStartsWithPackageName
type OpenIDAuthorizeRequestProcessor struct {
issuer security.Issuer
jwtDecoder jwt.JwtDecoder
fallbackJwtDecoder jwt.JwtDecoder
}
type ARPOptions func(opt *ARPOption)
type ARPOption struct {
Issuer security.Issuer
JwtDecoder jwt.JwtDecoder
}
func NewOpenIDAuthorizeRequestProcessor(opts ...ARPOptions) *OpenIDAuthorizeRequestProcessor {
opt := ARPOption{}
for _, f := range opts {
f(&opt)
}
return &OpenIDAuthorizeRequestProcessor{
issuer: opt.Issuer,
jwtDecoder: opt.JwtDecoder,
fallbackJwtDecoder: jwt.NewPlaintextJwtDecoder(),
}
}
func (p *OpenIDAuthorizeRequestProcessor) Process(ctx context.Context, request *auth.AuthorizeRequest, chain auth.AuthorizeRequestProcessChain) (validated *auth.AuthorizeRequest, err error) {
// first thing first, is "openid" scope requested?
if !request.Scopes.Has(oauth2.ScopeOidc) {
return chain.Next(ctx, request)
}
if e := p.validateResponseTypes(ctx, request); e != nil {
return nil, e
}
// attempt to decode from request object
if request, err = p.decodeRequestObject(ctx, request); err != nil {
return
}
// continue with the chain
if request, err = chain.Next(ctx, request); err != nil {
return
}
// additional checks
if e := p.validateImplicitFlow(ctx, request); e != nil {
return nil, e
}
if e := p.validateDisplay(ctx, request); e != nil {
return nil, e
}
cr, e := p.validateClaims(ctx, request)
if e != nil {
return nil, e
}
if e := p.validateAcrValues(ctx, request, cr); e != nil {
return nil, e
}
if e := p.processMaxAge(ctx, request); e != nil {
return nil, e
}
if e := p.processPrompt(ctx, request); e != nil {
return nil, e
}
return request, nil
}
func (p *OpenIDAuthorizeRequestProcessor) decodeRequestObject(ctx context.Context, request *auth.AuthorizeRequest) (*auth.AuthorizeRequest, error) {
reqUri, uriOk := request.Parameters[oauth2.ParameterRequestUri]
reqObj, objOk := request.Parameters[oauth2.ParameterRequestObj]
switch {
case !uriOk && !objOk:
return request, nil
case uriOk && objOk:
return nil, oauth2.NewInvalidAuthorizeRequestError(fmt.Errorf("%s and %s are exclusive", oauth2.ParameterRequestUri, oauth2.ParameterRequestObj))
case uriOk:
if strings.HasPrefix(strings.ToLower(reqUri), "https:") {
return nil, oauth2.NewInvalidAuthorizeRequestError(fmt.Errorf("%s must use https", oauth2.ParameterRequestUri))
}
bytes, e := httpGet(ctx, reqUri)
if e != nil {
return nil, oauth2.NewInvalidAuthorizeRequestError(fmt.Errorf("unable to fetch request object from %s: %v", oauth2.ParameterRequestUri, e))
}
reqObj = string(bytes)
}
// decode JWT using configured decoder, fallback to plaintext decoder
claims, e := p.jwtDecoder.Decode(ctx, reqObj)
if e != nil {
if claims, e = p.fallbackJwtDecoder.Decode(ctx, reqObj); e != nil {
return nil, oauth2.NewInvalidAuthorizeRequestError(fmt.Errorf("invalid request object: %v", e))
}
}
//nolint:contextcheck
decoded, e := claimsToAuthRequest(request.Context(), claims)
if e != nil {
return nil, oauth2.NewInvalidAuthorizeRequestError(fmt.Errorf("invalid request object: %v", e))
}
switch {
case !request.ResponseTypes.Equals(decoded.ResponseTypes):
return nil, oauth2.NewInvalidAuthorizeRequestError(fmt.Errorf("invalid request object - inconsistant response type"))
case !decoded.Scopes.Has(oauth2.ScopeOidc):
return nil, oauth2.NewInvalidAuthorizeRequestError(fmt.Errorf("invalid request object - missing 'openid' scope"))
}
return decoded, nil
}
func (p *OpenIDAuthorizeRequestProcessor) validateResponseTypes(ctx context.Context, request *auth.AuthorizeRequest) error {
return auth.ValidateResponseTypes(ctx, request, supportedResponseTypes)
}
func (p *OpenIDAuthorizeRequestProcessor) validateImplicitFlow(_ context.Context, request *auth.AuthorizeRequest) error {
if !request.ResponseTypes.Has("id_token") && !request.ResponseTypes.Has("token") {
return nil
}
// use of nonce is required when implicit flow is used without response type "code"
nonce, ok := request.Parameters[oauth2.ParameterNonce]
if !request.ResponseTypes.Has("code") && (!ok || nonce == "") {
return oauth2.NewInvalidAuthorizeRequestError("nonce is required for implicit flow")
}
return nil
}
// validateDisplay inspects the optional "display" parameter. Unsupported display modes are tolerated: they are only
// logged at info level and never cause the request to be rejected, so this always returns nil.
func (p *OpenIDAuthorizeRequestProcessor) validateDisplay(ctx context.Context, request *auth.AuthorizeRequest) error {
	if display := request.Parameters[oauth2.ParameterDisplay]; display != "" && !SupportedDisplayMode.Has(display) {
		logger.WithContext(ctx).Infof("unsupported display mode [%s] was requested.", display)
	}
	return nil
}
// https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter
func (p *OpenIDAuthorizeRequestProcessor) validateClaims(_ context.Context, request *auth.AuthorizeRequest) (*ClaimsRequest, error) {
raw, ok := request.Parameters[oauth2.ParameterClaims]
if !ok {
return nil, nil
}
cr := ClaimsRequest{}
if e := json.Unmarshal([]byte(raw), &cr); e != nil {
// maybe we should ignore this error
return nil, oauth2.NewInvalidAuthorizeRequestError(`invalid "claims" parameter`)
}
return &cr, nil
}
// validateAcrValues checks requested Authentication Context Class Reference (ACR) levels against the levels this
// server can provide. Nothing is checked unless the "acr_values" parameter is present.
//
// Values in "acr_values" and a non-essential "acr" entry in the id_token section of the "claims" parameter are
// collected as optional. An essential "acr" entry contributes required values. For now only required values are
// enforced, consistent with the Java implementation: the request passes if at least one required level is supported
// (levels 0-2 always, level 3 when MFA is possible). Otherwise a granter-not-available error is returned.
func (p *OpenIDAuthorizeRequestProcessor) validateAcrValues(_ context.Context, request *auth.AuthorizeRequest, claimsReq *ClaimsRequest) error {
	acrVals, present := request.Parameters[oauth2.ParameterACR]
	if !present {
		return nil
	}

	optional := utils.NewStringSet(strings.Split(acrVals, " ")...)
	optional.Remove("")
	required := utils.NewStringSet()
	if claimsReq != nil && len(claimsReq.IdToken) != 0 {
		acr, found := claimsReq.IdToken.Get(oauth2.ClaimAuthCtxClassRef)
		if found && !acr.IsDefault() {
			if acr.Essential() {
				required.Add(acr.Values()...)
			} else {
				optional.Add(acr.Values()...)
			}
		}
	}

	// levels this server can satisfy
	levels := []int{0, 1, 2}
	if isMFAPossible() {
		levels = append(levels, 3)
	}
	supported := utils.NewStringSet()
	for _, lvl := range levels {
		supported.Add(p.issuer.LevelOfAssurance(lvl))
	}

	if len(required) == 0 {
		return nil
	}
	for lvl := range required {
		if supported.Has(lvl) {
			return nil
		}
	}
	return oauth2.NewGranterNotAvailableError("requested acr level is not possible")
}
// processMaxAge applies the OpenID Connect "max_age" authorize request parameter.
//
// "max_age" is the allowable elapsed time, in seconds, since the user last actively authenticated. If the current
// security context is fully authenticated, has a known authentication time, and that time is older than max_age,
// the current authentication is cleared so the user has to authenticate again.
//
// Per OpenID specs the authorize endpoint ignores invalid request parameters. Any value that is not a non-negative
// base-10 integer (e.g. "10m", "-5", "1.5"), or that overflows time.Duration, is ignored rather than rejected.
func (p *OpenIDAuthorizeRequestProcessor) processMaxAge(ctx context.Context, request *auth.AuthorizeRequest) error {
	maxAgeStr, ok := request.Parameters[oauth2.ParameterMaxAge]
	if !ok {
		return nil
	}
	// Accept plain digits only. Suffixing an arbitrary string with "s" and parsing it as a Go duration would also
	// accept unit suffixes ("10m" -> "10ms", i.e. 10 milliseconds), signs and fractions that the spec does not allow.
	if maxAgeStr == "" || strings.TrimLeft(maxAgeStr, "0123456789") != "" {
		return nil
	}
	maxAge, e := time.ParseDuration(maxAgeStr + "s")
	if e != nil {
		// only reachable on overflow at this point
		return nil //nolint:nilerr // per OpenID specs, "authorize" endpoint should simply ignore invalid request params
	}
	current := security.Get(ctx)
	authTime := security.DetermineAuthenticationTime(ctx, current)
	if !security.IsFullyAuthenticated(current) || authTime.IsZero() {
		return nil
	}
	if authTime.Add(maxAge).Before(time.Now()) {
		security.MustClear(ctx)
	}
	return nil
}
// processPrompt applies the OpenID Connect "prompt" authorize request parameter (a space-delimited list).
//
//   - "none": returns an interaction_required error if it is combined with any other value or if the user is not
//     currently fully authenticated.
//   - "login": when the user is currently authenticated, clears the current authentication. It also marks the current
//     http request, so the login loop is broken once the (request-cached) request is replayed after login.
//
// "select_account" is not supported, and "consent" is checked later when deciding whether to show user approval.
func (p *OpenIDAuthorizeRequestProcessor) processPrompt(ctx context.Context, request *auth.AuthorizeRequest) error {
	prompt, ok := request.Parameters[oauth2.ParameterPrompt]
	if !ok || prompt == "" {
		return nil
	}
	prompts := utils.NewStringSet(strings.Split(prompt, " ")...)
	// Splitting on single spaces yields "" for leading/trailing/repeated spaces. Drop it (as validateAcrValues does)
	// so that e.g. "none " is not mistaken for "none" combined with another value.
	prompts.Remove("")

	// handle "none"
	if prompts.Has(PromptNone) && (len(prompts) > 1 || !isCurrentlyAuthenticated(ctx)) {
		return NewInteractionRequiredError("unable to authenticate without interact with user")
	}

	// handle "login"
	// to break the login loop, we put a special header to current http request and it will be saved by request cache
	if prompts.Has(PromptLogin) && !isPromptLoginProcessed(ctx) && isCurrentlyAuthenticated(ctx) {
		security.MustClear(ctx)
		if e := setPromptLoginProcessed(ctx); e != nil {
			return NewLoginRequiredError("unable to initiate login", e)
		}
	}
	// We don't support "select_account" and "consent" is checked when we have decided to show user approval
	return nil
}
/*********************
Helpers
*********************/
// isCurrentlyAuthenticated reports whether the authentication stored in ctx is fully authenticated.
func isCurrentlyAuthenticated(ctx context.Context) bool {
	current := security.Get(ctx)
	return security.IsFullyAuthenticated(current)
}
// isMFAPossible reports whether multi-factor authentication can be performed, which makes the highest level of
// assurance achievable. It is currently hard-coded to true.
func isMFAPossible() bool {
	const possible = true
	return possible
}
// isPromptLoginProcessed reports whether the current http request carries the marker header written by
// setPromptLoginProcessed. It returns false when no http request can be extracted from ctx.
func isPromptLoginProcessed(ctx context.Context) bool {
	if req := getHttpRequest(ctx); req != nil {
		return req.Header.Get(keyPromptProcessed) == PromptLogin
	}
	return false
}
// setPromptLoginProcessed marks the current http request as having already handled prompt=login by setting the
// keyPromptProcessed header. It returns an error when no http request can be extracted from ctx.
func setPromptLoginProcessed(ctx context.Context) error {
	req := getHttpRequest(ctx)
	if req != nil {
		req.Header.Set(keyPromptProcessed, PromptLogin)
		return nil
	}
	return fmt.Errorf("unable to extract http request")
}
// getHttpRequest returns the *http.Request of the gin context carried by ctx, or nil if ctx has no gin context.
func getHttpRequest(ctx context.Context) *http.Request {
	gc := web.GinContext(ctx)
	if gc == nil {
		return nil
	}
	return gc.Request
}
// maxRequestObjectBytes caps how much of a fetched request object is read. The URL comes from the client-supplied
// "request_uri" parameter, so the response size must not be trusted.
const maxRequestObjectBytes int64 = 1 << 20

// httpGet fetches urlStr with a GET request bound to ctx (via http.DefaultClient) and returns the response body.
//
// It returns an error if urlStr cannot be parsed, the request fails, the status code is not 2xx (the code is
// included in the message), or the body is larger than maxRequestObjectBytes.
func httpGet(ctx context.Context, urlStr string) ([]byte, error) {
	parsed, e := url.Parse(urlStr)
	if e != nil {
		return nil, e
	}
	req, e := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), http.NoBody)
	if e != nil {
		return nil, e
	}
	resp, e := http.DefaultClient.Do(req)
	if e != nil {
		return nil, e
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("non-2XX status code: %d", resp.StatusCode)
	}
	// read one extra byte so an oversized body can be detected instead of being silently truncated
	body, e := io.ReadAll(io.LimitReader(resp.Body, maxRequestObjectBytes+1))
	if e != nil {
		return nil, e
	}
	if int64(len(body)) > maxRequestObjectBytes {
		return nil, fmt.Errorf("response body exceeds %d bytes", maxRequestObjectBytes)
	}
	return body, nil
}
// claimsToAuthRequest builds an authorize request from the key/value pairs of a decoded request object, delegating
// parsing to auth.ParseAuthorizeRequestWithKVs.
func claimsToAuthRequest(ctx context.Context, claims oauth2.Claims) (*auth.AuthorizeRequest, error) {
	kvs := claims.Values()
	return auth.ParseAuthorizeRequestWithKVs(ctx, kvs)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package openid
import (
"encoding/json"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/claims"
)
// ClaimsRequest is the decoded form of the OpenID Connect "claims" authorize request parameter.
// See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter
type ClaimsRequest struct {
	// UserInfo holds the claims requested under the "userinfo" member.
	UserInfo requestedClaims `json:"userinfo"`
	// IdToken holds the claims requested under the "id_token" member.
	IdToken requestedClaims `json:"id_token"`
}

// requestedClaims implements claims.RequestedClaims as a map keyed by claim name.
type requestedClaims map[string]requestedClaim

// Get looks up the request for the named claim. The boolean reports whether the claim was requested. When it was
// not, the returned value wraps a zero requestedClaim rather than being a nil interface.
func (r requestedClaims) Get(claim string) (claims.RequestedClaim, bool) {
	entry, found := r[claim]
	return entry, found
}
// rcDetails mirrors the JSON members of a single entry in the "claims" request parameter.
type rcDetails struct {
	Essential   bool     `json:"essential"`
	Values      []string `json:"values,omitempty"`
	SingleValue *string  `json:"value,omitempty"`
}

// requestedClaim implements claims.RequestedClaim and json.Unmarshaler.
// During unmarshalling, the single "value" form is normalised into Values.
type requestedClaim struct {
	rcDetails
}

// Essential reports whether the claim was marked "essential".
// Note: r.Essential would resolve to this method, so the embedded field is accessed explicitly.
func (r requestedClaim) Essential() bool {
	details := r.rcDetails
	return details.Essential
}

// Values returns the requested values ("value" takes precedence over "values" when both were sent).
func (r requestedClaim) Values() []string {
	details := r.rcDetails
	return details.Values
}

// IsDefault reports whether no specific value was requested for the claim.
func (r requestedClaim) IsDefault() bool {
	return len(r.Values()) == 0
}

// UnmarshalJSON decodes a claim request object. Values always ends up non-nil. If "value" is present, it replaces
// Values with a single-element slice and SingleValue is reset to nil.
func (r *requestedClaim) UnmarshalJSON(data []byte) error {
	details := &r.rcDetails
	details.Values = []string{}
	if err := json.Unmarshal(data, details); err != nil {
		return err
	}
	if sv := details.SingleValue; sv != nil {
		details.Values, details.SingleValue = []string{*sv}, nil
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package openid
import (
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// logger is the package logger, created via log.New with the name "OpenID".
var (
	logger = log.New("OpenID")
)
// Values of the OpenID Connect "prompt" authorize request parameter.
//
//goland:noinspection GoUnusedConst
const (
	PromptNone  = `none`
	PromptLogin = `login`
	//PromptConsent = `consent`
	//PromptSelectAcct = `select_account`
)

// Values of the OpenID Connect "display" authorize request parameter.
const (
	DisplayPage  = `page`
	DisplayTouch = `touch`
	//DisplayPopup = `popup`
	//DisplayWap = `wap`
)

// PromptTouch is the "touch" value of the "display" parameter; the name is kept for backward compatibility.
//
// Deprecated: "touch" is a display mode, not a prompt value. Use DisplayTouch instead.
const PromptTouch = DisplayTouch

// WellKnownEndpointOPConfig is the path of the OpenID Provider configuration (discovery) document.
const (
	WellKnownEndpointOPConfig = `/.well-known/openid-configuration`
)
// SupportedGrantTypes lists the grant types for which the OpenID token enhancer may issue an ID token.
var SupportedGrantTypes = utils.NewStringSet(
	oauth2.GrantTypeAuthCode,
	oauth2.GrantTypeImplicit,
	oauth2.GrantTypePassword,
	oauth2.GrantTypeSwitchUser,
	oauth2.GrantTypeSwitchTenant,
)

// SupportedDisplayMode lists the "display" values this server recognises. Other values are only logged.
var SupportedDisplayMode = utils.NewStringSet(DisplayPage, PromptTouch)

// FullIdTokenGrantTypes lists grant types that do not use a response type. For legacy support they presumably receive
// a full ID token regardless of requested scopes (TODO confirm against determineClaimSpecs).
var FullIdTokenGrantTypes = utils.NewStringSet(
	oauth2.GrantTypePassword,
	oauth2.GrantTypeSwitchUser,
	oauth2.GrantTypeSwitchTenant,
)
// Keys of the OpenID Provider metadata document.
// See https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
//
//goland:noinspection GoUnusedConst
//nolint:gosec
const (
	// identity and endpoints
	OPMetadataIssuer           = "issuer"
	OPMetadataAuthEndpoint     = "authorization_endpoint"
	OPMetadataTokenEndpoint    = "token_endpoint"
	OPMetadataUserInfoEndpoint = "userinfo_endpoint"
	OPMetadataJwkSetURI        = "jwks_uri"
	OPMetadataRegEndpoint      = "registration_endpoint"

	// supported protocol features
	OPMetadataScopes        = "scopes_supported"
	OPMetadataResponseTypes = "response_types_supported"
	OPMetadataResponseModes = "response_modes_supported"
	OPMetadataGrantTypes    = "grant_types_supported"
	OPMetadataACRValues     = "acr_values_supported"
	OPMetadataSubjectTypes  = "subject_types_supported"

	// JWS/JWE algorithms for ID tokens, UserInfo responses and request objects
	OPMetadataIdTokenJwsAlg  = "id_token_signing_alg_values_supported"
	OPMetadataIdTokenJweAlg  = "id_token_encryption_alg_values_supported"
	OPMetadataIdTokenJweEnc  = "id_token_encryption_enc_values_supported"
	OPMetadataUserInfoJwsAlg = "userinfo_signing_alg_values_supported"
	OPMetadataUserInfoJweAlg = "userinfo_encryption_alg_values_supported"
	OPMetadataUserInfoJweEnc = "userinfo_encryption_enc_values_supported"
	OPMetadataRequestJwsAlg  = "request_object_signing_alg_values_supported"
	OPMetadataRequestJweAlg  = "request_object_encryption_alg_values_supported"
	OPMetadataRequestJweEnc  = "request_object_encryption_enc_values_supported"

	// client authentication at the token endpoint
	OPMetadataClientAuthMethod = "token_endpoint_auth_methods_supported"
	OPMetadataAuthJwsAlg       = "token_endpoint_auth_signing_alg_values_supported"

	// display, claims, documentation and localization
	OPMetadataDisplayValues = "display_values_supported"
	OPMetadataClaimTypes    = "claim_types_supported"
	OPMetadataClaims        = "claims_supported"
	OPMetadataServiceDocs   = "service_documentation"
	OPMetadataClaimsLocales = "claims_locales_supported"
	OPMetadataUILocales     = "ui_locales_supported"

	// request parameter support
	OPMetadataClaimsParams          = "claims_parameter_supported"
	OPMetadataRequestParams         = "request_parameter_supported"
	OPMetadataRequestUriParams      = "request_uri_parameter_supported"
	OPMetadataRequiresRequestUriReg = "require_request_uri_registration"

	// policies
	OPMetadataPolicyUri = "op_policy_uri"
	OPMetadataTosUri    = "op_tos_uri"

	// OpenID Connect RP-Initiated Logout
	OPMetadataEndSessionEndpoint = "end_session_endpoint"
)
// OPMetadata is the OpenID Provider metadata document (see the OPMetadata* key constants).
// It leverages claims implementations: the typed fields below are mapped to and from their `claim` tag names
// through the embedded oauth2.FieldClaimsMapper (see the methods on OPMetadata).
// NOTE(review): oauth2.MapClaims is presumably embedded to carry metadata entries that have no dedicated field;
// confirm against FieldClaimsMapper.
type OPMetadata struct {
	oauth2.FieldClaimsMapper
	oauth2.MapClaims
	Issuer string `claim:"issuer"`
	AuthEndpoint string `claim:"authorization_endpoint"`
	TokenEndpoint string `claim:"token_endpoint"`
	UserInfoEndpoint string `claim:"userinfo_endpoint"`
	JwkSetURI string `claim:"jwks_uri"`
	SupportedGrantTypes utils.StringSet `claim:"grant_types_supported"`
	SupportedScopes utils.StringSet `claim:"scopes_supported"`
	SupportedResponseTypes utils.StringSet `claim:"response_types_supported"`
	SupportedACRs utils.StringSet `claim:"acr_values_supported"`
	SupportedSubjectTypes utils.StringSet `claim:"subject_types_supported"`
	SupportedIdTokenJwsAlg utils.StringSet `claim:"id_token_signing_alg_values_supported"`
	SupportedClaims utils.StringSet `claim:"claims_supported"`
}
// MarshalJSON serialises the metadata through the embedded FieldClaimsMapper.
func (md OPMetadata) MarshalJSON() ([]byte, error) {
	return md.FieldClaimsMapper.DoMarshalJSON(md)
}

// UnmarshalJSON populates the metadata through the embedded FieldClaimsMapper.
func (md *OPMetadata) UnmarshalJSON(bytes []byte) error {
	return md.FieldClaimsMapper.DoUnmarshalJSON(md, bytes)
}

// Get returns the value of the named metadata entry, as resolved by FieldClaimsMapper.
func (md OPMetadata) Get(claim string) interface{} {
	return md.FieldClaimsMapper.Get(md, claim)
}

// Has reports whether the named metadata entry is present, as resolved by FieldClaimsMapper.
func (md OPMetadata) Has(claim string) bool {
	return md.FieldClaimsMapper.Has(md, claim)
}

// Set assigns the named metadata entry through FieldClaimsMapper.
func (md *OPMetadata) Set(claim string, value interface{}) {
	md.FieldClaimsMapper.Set(md, claim, value)
}

// Values returns all metadata entries as a map, as produced by FieldClaimsMapper.
func (md OPMetadata) Values() map[string]interface{} {
	return md.FieldClaimsMapper.Values(md)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package openid
import (
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
errorutils "github.com/cisco-open/go-lanai/pkg/utils/error"
"net/http"
)
const (
	_ = iota
	// ErrorSubTypeCodeOidcSlo non-programming error that can occur during oidc RP initiated logout.
	// iota is 1 on this line and << binds tighter than +, so the value is
	// security.ErrorTypeCodeOidc + (1 << errorutils.ErrorSubTypeOffset): the first sub-type under the OIDC error type.
	ErrorSubTypeCodeOidcSlo = security.ErrorTypeCodeOidc + iota<<errorutils.ErrorSubTypeOffset
)

// Error codes within the ErrorSubTypeCodeOidcSlo sub-type. The blank identifier consumes offset 0, so
// ErrorCodeOidcSloRp = ErrorSubTypeCodeOidcSlo + 1 and ErrorCodeOidcSloOp = ErrorSubTypeCodeOidcSlo + 2.
const (
	_ = ErrorSubTypeCodeOidcSlo + iota
	ErrorCodeOidcSloRp
	ErrorCodeOidcSloOp
)
// ErrorSubTypeOidcSlo is the error sub-type shared by all OIDC RP-initiated logout (SLO) errors.
var ErrorSubTypeOidcSlo = security.NewErrorSubType(ErrorSubTypeCodeOidcSlo, errors.New("error sub-type: oidc slo"))

// ErrorOidcSloRp is the base for SLO errors attributable to the relying party's request (e.g. a bad id_token_hint
// or an unregistered redirect URI). Originally documented as displayed as an HTML page with status 400.
var ErrorOidcSloRp = security.NewCodedError(ErrorCodeOidcSloRp, "SLO rp error")

// ErrorOidcSloOp is the base for SLO errors on the provider side (e.g. the session user or client cannot be
// resolved). Originally documented as displayed as an HTML page with status 500.
var ErrorOidcSloOp = security.NewCodedError(ErrorCodeOidcSloOp, "SLO op error")
func newOpenIDExtendedError(oauth2Code string, value interface{}, causes []interface{}) error {
return oauth2.NewOAuth2Error(oauth2.ErrorCodeOpenIDExt, value,
oauth2Code, http.StatusBadRequest, causes...)
}
func NewOpenIDExtendedError(oauth2Code string, value interface{}, causes ...interface{}) error {
return newOpenIDExtendedError(oauth2Code, value, causes)
}
func NewInteractionRequiredError(value interface{}, causes ...interface{}) error {
return newOpenIDExtendedError(oauth2.ErrorTranslationInteractionRequired, value, causes)
}
func NewLoginRequiredError(value interface{}, causes ...interface{}) error {
return newOpenIDExtendedError(oauth2.ErrorTranslationLoginRequired, value, causes)
}
func NewAccountSelectionRequiredError(value interface{}, causes ...interface{}) error {
return newOpenIDExtendedError(oauth2.ErrorTranslationAcctSelectRequired, value, causes)
}
func NewInvalidRequestURIError(value interface{}, causes ...interface{}) error {
return newOpenIDExtendedError(oauth2.ErrorTranslationInvalidRequestURI, value, causes)
}
func NewInvalidRequestObjError(value interface{}, causes ...interface{}) error {
return newOpenIDExtendedError(oauth2.ErrorTranslationInvalidRequestObj, value, causes)
}
func NewRequestNotSupportedError(value interface{}, causes ...interface{}) error {
return newOpenIDExtendedError(oauth2.ErrorTranslationRequestUnsupported, value, causes)
}
func NewRequestURINotSupportedError(value interface{}, causes ...interface{}) error {
return newOpenIDExtendedError(oauth2.ErrorTranslationRequestURIUnsupported, value, causes)
}
func NewRegistrationNotSupportedError(value interface{}, causes ...interface{}) error {
return newOpenIDExtendedError(oauth2.ErrorTranslationRegistrationUnsupported, value, causes)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package openid
import (
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/claims"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
// IdTokenClaims implements oauth2.Claims
// It is the typed content of an OpenID Connect ID Token. Each field is mapped to/from the claim named in its
// `claim` tag by the embedded oauth2.FieldClaimsMapper (see the methods on IdTokenClaims).
type IdTokenClaims struct {
	oauth2.FieldClaimsMapper
	/*******************************
	* Standard Jwt claims
	*******************************/
	Issuer string `claim:"iss"`
	Subject string `claim:"sub"`
	Audience oauth2.StringSetClaim `claim:"aud"`
	Expire time.Time `claim:"exp"`
	IssueAt time.Time `claim:"iat"`
	/*******************************
	* Standard ID Token claims
	*******************************/
	/* Standard */
	AuthTime time.Time `claim:"auth_time"`
	Nonce string `claim:"nonce"`
	AuthCtxClassRef string `claim:"acr"`
	AuthMethodRef []string `claim:"amr"`
	AuthorizedParty string `claim:"azp"`
	AccessTokenHash string `claim:"at_hash"`
	/* Profile Scope */
	FullName string `claim:"name"`
	FirstName string `claim:"given_name"`
	LastName string `claim:"family_name"`
	MiddleName string `claim:"middle_name"`
	Nickname string `claim:"nickname"`
	PreferredUsername string `claim:"preferred_username"`
	ProfileUrl string `claim:"profile"`
	PictureUrl string `claim:"picture"`
	Website string `claim:"website"`
	Gender string `claim:"gender"`
	Birthday string `claim:"birthdate"` // ISO 8601:2004 [ISO8601-2004] YYYY-MM-DD format
	ZoneInfo string `claim:"zoneinfo"` // IANA time zone name, e.g. Europe/Paris or America/Los_Angeles
	Locale string `claim:"locale"` // BCP47 language tag: typically an ISO 639-1 Alpha-2 language code in lowercase and an ISO 3166-1 Alpha-2 country code in uppercase, separated by a dash (e.g. en-US)
	UpdatedAt time.Time `claim:"updated_at"`
	/* Email Scope */
	Email string `claim:"email"`
	EmailVerified *bool `claim:"email_verified"` // pointer presumably so "absent" differs from false
	/* Phone Number Scope */
	PhoneNumber string `claim:"phone_number"` // RFC 3966 [RFC3966] e.g. +1 (604) 555-1234;ext=5678
	PhoneNumVerified *bool `claim:"phone_number_verified"`
	/* Address Scope */
	Address *claims.AddressClaim `claim:"address"`
	/*******************************
	* NFV Additional Claims
	*******************************/
	/* Profile Scope */
	DefaultTenantId string `claim:"default_tenant_id"`
	AssignedTenants utils.StringSet `claim:"assigned_tenants"`
	Roles utils.StringSet `claim:"roles"`
	Permissions utils.StringSet `claim:"permissions"`
	/* General Scope */
	UserId string `claim:"user_id"`
	AccountType string `claim:"account_type"`
	TenantId string `claim:"tenant_id"`
	TenantExternalId string `claim:"tenant_name"` //for backward compatibility, map to tenant_name
	TenantSuspended *bool `claim:"tenant_suspended"`
	ProviderId string `claim:"provider_id"`
	ProviderName string `claim:"provider_name"`
	OrigUsername string `claim:"original_username"`
	Currency string `claim:"currency"`
}
// MarshalJSON serialises the ID token claims through the embedded FieldClaimsMapper.
func (tc *IdTokenClaims) MarshalJSON() ([]byte, error) {
	return tc.FieldClaimsMapper.DoMarshalJSON(tc)
}

// UnmarshalJSON populates the ID token claims through the embedded FieldClaimsMapper.
func (tc *IdTokenClaims) UnmarshalJSON(bytes []byte) error {
	return tc.FieldClaimsMapper.DoUnmarshalJSON(tc, bytes)
}

// Get returns the value of the named claim, as resolved by FieldClaimsMapper.
func (tc *IdTokenClaims) Get(claim string) interface{} {
	return tc.FieldClaimsMapper.Get(tc, claim)
}

// Has reports whether the named claim is present, as resolved by FieldClaimsMapper.
func (tc *IdTokenClaims) Has(claim string) bool {
	return tc.FieldClaimsMapper.Has(tc, claim)
}

// Set assigns the named claim through FieldClaimsMapper.
func (tc *IdTokenClaims) Set(claim string, value interface{}) {
	tc.FieldClaimsMapper.Set(tc, claim, value)
}

// Values returns all claims as a map, as produced by FieldClaimsMapper.
func (tc *IdTokenClaims) Values() map[string]interface{} {
	return tc.FieldClaimsMapper.Values(tc)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package openid
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
netutil "github.com/cisco-open/go-lanai/pkg/utils/net"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"net/http"
"net/url"
"strings"
)
// The OIDC RP-initiated SLO is implemented by a set of handlers:
// 1. ConditionalHandler (OidcLogoutHandler)
// 2. SuccessHandler (OidcSuccessHandler)
// 3. EntryPoint (OidcEntryPoint)
// They work together according to the logic in security/logout/LogoutMiddleware.
// First the ConditionalHandler is executed. It checks whether all the conditions required to process the logout
// request are met: the post_logout_redirect_uri is present, the id_token_hint is valid, and the redirect URL can be
// verified against the information presented by the ID token. This happens in the ShouldLogout method.
// If ShouldLogout succeeds, the LogoutHandler chain continues and the user is logged out.
// After the logout handler chain finishes, the SuccessHandler is called and the user is redirected according to the
// post_logout_redirect_uri.
// If ShouldLogout returns an error, the logout process is stopped and the EntryPoint is called, which directs the user
// to an error page.
// Request parameters of OpenID Connect RP-Initiated Logout.
//
//nolint:gosec // not sure why linter think this is the case: "G101: Potential hardcoded credentials"
var (
	ParameterRedirectUri = "post_logout_redirect_uri"
	ParameterIdTokenHint = "id_token_hint"
	ParameterState       = "state"
)
// SuccessOptions customizes SuccessOption when constructing an OidcSuccessHandler.
type SuccessOptions func(opt *SuccessOption)

// SuccessOption holds the configuration of an OidcSuccessHandler.
type SuccessOption struct {
	ClientStore         oauth2.OAuth2ClientStore
	WhitelabelErrorPath string
}

// OidcSuccessHandler completes an OIDC RP-initiated logout by redirecting to the requested post_logout_redirect_uri.
type OidcSuccessHandler struct {
	clientStore oauth2.OAuth2ClientStore
	fallback    security.AuthenticationErrorHandler
}

// NewOidcSuccessHandler builds an OidcSuccessHandler. Errors while building the redirect URL are delegated to a
// redirect handler pointing at WhitelabelErrorPath.
func NewOidcSuccessHandler(opts ...SuccessOptions) *OidcSuccessHandler {
	var opt SuccessOption
	for _, apply := range opts {
		apply(&opt)
	}
	h := &OidcSuccessHandler{clientStore: opt.ClientStore}
	h.fallback = redirect.NewRedirectWithURL(opt.WhitelabelErrorPath)
	return h
}

// Order returns order.Highest.
func (o *OidcSuccessHandler) Order() int {
	return order.Highest
}

// HandleAuthenticationSuccess redirects (302) to post_logout_redirect_uri, appending "state" when provided.
// Requests without post_logout_redirect_uri are ignored.
func (o *OidcSuccessHandler) HandleAuthenticationSuccess(c context.Context, r *http.Request, rw http.ResponseWriter, from, to security.Authentication) {
	target := r.FormValue(ParameterRedirectUri)
	if target == "" {
		// as OIDC success handler, we only care about this redirect
		return
	}
	params := map[string]string{}
	if state := r.FormValue(ParameterState); state != "" {
		params[ParameterState] = state
	}
	target, err := netutil.AppendRedirectUrl(target, params)
	if err != nil {
		o.fallback.HandleAuthenticationError(c, r, rw, err)
		return
	}
	// the logout handler already validated the request and the redirect uri, so we only need to redirect here
	http.Redirect(rw, r, target, http.StatusFound)
	_, _ = rw.Write([]byte{})
}
// HandlerOptions customizes HandlerOption when constructing an OidcLogoutHandler.
type HandlerOptions func(opt *HandlerOption)

// HandlerOption holds the dependencies of an OidcLogoutHandler.
type HandlerOption struct {
	Dec         jwt.JwtDecoder
	Issuer      security.Issuer
	ClientStore oauth2.OAuth2ClientStore
}

// OidcLogoutHandler validates OIDC RP-initiated logout requests (see ShouldLogout).
type OidcLogoutHandler struct {
	dec         jwt.JwtDecoder
	issuer      security.Issuer
	clientStore oauth2.OAuth2ClientStore
}

// NewOidcLogoutHandler builds an OidcLogoutHandler from the given options.
func NewOidcLogoutHandler(opts ...HandlerOptions) *OidcLogoutHandler {
	var opt HandlerOption
	for _, apply := range opts {
		apply(&opt)
	}
	return &OidcLogoutHandler{dec: opt.Dec, issuer: opt.Issuer, clientStore: opt.ClientStore}
}

// Order returns order.Highest.
func (o *OidcLogoutHandler) Order() int {
	return order.Highest
}
func (o *OidcLogoutHandler) ShouldLogout(ctx context.Context, request *http.Request, writer http.ResponseWriter, authentication security.Authentication) error {
switch request.Method {
case http.MethodGet:
fallthrough
case http.MethodPost:
case http.MethodPut:
fallthrough
case http.MethodDelete:
fallthrough
default:
return ErrorOidcSloRp.WithMessage("unsupported http verb %v", request.Method)
}
//if logout request doesn't have this, we don't consider it a oidc logout request, and let other handle it.
redirectUri := request.FormValue(ParameterRedirectUri)
if redirectUri == "" {
return nil
}
idTokenValue := request.FormValue(ParameterIdTokenHint)
if strings.TrimSpace(idTokenValue) == "" {
return ErrorOidcSloRp.WithMessage(`id token is required from parameter "%s"`, ParameterIdTokenHint)
}
claims, err := o.dec.Decode(ctx, idTokenValue)
if err != nil {
return ErrorOidcSloRp.WithMessage("id token invalid: %v", err)
}
iss := claims.Get(oauth2.ClaimIssuer)
if iss != o.issuer.Identifier() {
return ErrorOidcSloRp.WithMessage("id token is not issued by this auth server")
}
sub := claims.Get(oauth2.ClaimSubject)
username, err := security.GetUsername(authentication)
if err != nil {
return ErrorOidcSloOp.WithMessage("Couldn't identify current session user")
} else if sub != username {
return ErrorOidcSloRp.WithMessage("logout request rejected because id token is not from the current session's user.")
}
clientId := claims.Get(oauth2.ClaimAudience).(string)
client, err := auth.LoadAndValidateClientId(ctx, clientId, o.clientStore)
if err != nil {
return ErrorOidcSloOp.WithMessage("error loading client %s", clientId)
}
_, err = auth.ResolveRedirectUri(ctx, redirectUri, client)
if err != nil {
return ErrorOidcSloRp.WithMessage("redirect url %s is not registered by client %s", redirectUri, clientId)
}
r, err := url.Parse(redirectUri)
if err != nil {
return ErrorOidcSloRp.WithMessage("redirect url %s is not a valid url", redirectUri)
} else {
if r.RawQuery != "" {
return ErrorOidcSloRp.WithMessage("redirect url %s should not contain query parameter", redirectUri)
}
}
return nil
}
// HandleLogout is intentionally a no-op and always returns nil. The default logout handler is expected to be
// sufficient (deleting the current session etc.).
func (o *OidcLogoutHandler) HandleLogout(_ context.Context, _ *http.Request, _ http.ResponseWriter, _ security.Authentication) error {
	return nil
}
// EpOptions customizes EpOption when constructing an OidcEntryPoint.
type EpOptions func(opt *EpOption)

// EpOption holds the configuration of an OidcEntryPoint.
type EpOption struct {
	WhitelabelErrorPath string
}

// OidcEntryPoint handles errors raised by OidcLogoutHandler.ShouldLogout.
type OidcEntryPoint struct {
	fallback security.AuthenticationEntryPoint
}

// NewOidcEntryPoint builds an OidcEntryPoint that redirects to WhitelabelErrorPath.
func NewOidcEntryPoint(opts ...EpOptions) *OidcEntryPoint {
	var opt EpOption
	for _, apply := range opts {
		apply(&opt)
	}
	ep := &OidcEntryPoint{}
	ep.fallback = redirect.NewRedirectWithURL(opt.WhitelabelErrorPath)
	return ep
}

// Commence delegates errors of the ErrorSubTypeOidcSlo sub-type to the fallback redirect and ignores all others.
// RP and OP errors are currently handled the same way, because there are no RP- or OP-specific requirements yet.
func (o *OidcEntryPoint) Commence(ctx context.Context, request *http.Request, writer http.ResponseWriter, err error) {
	if errors.Is(err, ErrorSubTypeOidcSlo) {
		o.fallback.Commence(ctx, request, writer, err)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package openid
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/idp"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/claims"
"github.com/cisco-open/go-lanai/pkg/utils"
)
/************************
OPMetadata Specs
************************/
// OPMetaExtraSourceIDPManager is the claims.FactoryOption ExtraSource key where opMetaEndpoint expects an
// idp.IdentityProviderManager.
const OPMetaExtraSourceIDPManager = "idpManager"

// errorOPMetaClaimNotAvailable signals that an OP metadata entry cannot be computed from the given factory options.
var errorOPMetaClaimNotAvailable = fmt.Errorf("claim N/A")
// constantClaimSpec is a claims.ClaimSpec that always yields the same pre-computed value.
type constantClaimSpec struct {
	val interface{}
}

// Calculate returns the constant value and never fails.
func (s constantClaimSpec) Calculate(_ context.Context, _ *claims.FactoryOption) (interface{}, error) {
	return s.val, nil
}

// Required is always false: OP metadata entries are optional.
func (s constantClaimSpec) Required(_ context.Context, _ *claims.FactoryOption) bool {
	return false
}

// opMetaClaimSpec is a claims.ClaimSpec whose value is computed by fn.
type opMetaClaimSpec struct {
	fn claims.ClaimFactoryFunc
}

// Calculate invokes fn with the given context and factory options.
func (s opMetaClaimSpec) Calculate(ctx context.Context, opt *claims.FactoryOption) (interface{}, error) {
	calc := s.fn
	return calc(ctx, opt)
}

// Required is always false: OP metadata entries are optional.
func (s opMetaClaimSpec) Required(_ context.Context, _ *claims.FactoryOption) bool {
	return false
}
// opMetaFixedSet returns a spec that always yields a utils.StringSet of the given values.
func opMetaFixedSet(values ...string) claims.ClaimSpec {
	set := utils.NewStringSet(values...)
	return constantClaimSpec{val: set}
}

// opMetaFixedBool returns a spec that always yields v.
func opMetaFixedBool(v bool) claims.ClaimSpec {
	return constantClaimSpec{val: v}
}
// opMetaAcrValues returns a spec yielding the issuer's level-of-assurance identifiers for the given levels, as a
// utils.StringSet. It fails with errorOPMetaClaimNotAvailable when no issuer is configured.
func opMetaAcrValues(acrLevels ...int) claims.ClaimSpec {
	calc := func(_ context.Context, opt *claims.FactoryOption) (interface{}, error) {
		issuer := opt.Issuer
		if issuer == nil {
			return nil, errorOPMetaClaimNotAvailable
		}
		acrs := utils.NewStringSet()
		for i := range acrLevels {
			acrs.Add(issuer.LevelOfAssurance(acrLevels[i]))
		}
		return acrs, nil
	}
	return opMetaClaimSpec{fn: calc}
}
func opMetaEndpoint(epName string) claims.ClaimSpec {
return opMetaClaimSpec{
fn: func(ctx context.Context, opt *claims.FactoryOption) (v interface{}, err error) {
if opt.ExtraSource == nil || opt.Issuer == nil {
return nil, errorOPMetaClaimNotAvailable
}
relative, ok := opt.ExtraSource[epName].(string)
if !ok {
return nil, errorOPMetaClaimNotAvailable
}
// figure out domain
idpMgt, ok := opt.ExtraSource[OPMetaExtraSourceIDPManager].(idp.IdentityProviderManager)
if !ok {
return nil, errorOPMetaClaimNotAvailable
}
var domain string
for _, flow := range []idp.AuthenticationFlow{idp.InternalIdpForm, idp.ExternalIdpSAML} {
idps := idpMgt.GetIdentityProvidersWithFlow(ctx, flow)
if len(idps) != 0 {
domain = idps[0].Domain()
break
}
}
uri, e := opt.Issuer.BuildUrl(func(opt *security.UrlBuilderOption) {
opt.FQDN = domain
opt.Path = relative
})
if e != nil {
return nil, e
}
return uri.String(), nil
},
}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package openid
import (
	"context"
	"encoding/json"
	"fmt"

	"github.com/cisco-open/go-lanai/pkg/security"
	"github.com/cisco-open/go-lanai/pkg/security/oauth2"
	"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
	"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/claims"
	"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
)
/*****************************
ID Token Enhancer
*****************************/
// scopedSpecs maps each standard OpenID scope to the claim specs it unlocks.
var scopedSpecs = map[string]map[string]claims.ClaimSpec{
	oauth2.ScopeOidcProfile: claims.ProfileScopeSpecs,
	oauth2.ScopeOidcEmail:   claims.EmailScopeSpecs,
	oauth2.ScopeOidcPhone:   claims.PhoneScopeSpecs,
	oauth2.ScopeOidcAddress: claims.AddressScopeSpecs,
}

// defaultSpecs contains only the basic ID token claim specs.
var defaultSpecs = []map[string]claims.ClaimSpec{
	claims.IdTokenBasicSpecs,
}

// fullSpecs contains the basic ID token claim specs plus every scope-specific set.
var fullSpecs = []map[string]claims.ClaimSpec{
	claims.IdTokenBasicSpecs,
	claims.ProfileScopeSpecs,
	claims.EmailScopeSpecs,
	claims.PhoneScopeSpecs,
	claims.AddressScopeSpecs,
}
// EnhancerOptions customizes EnhancerOption when constructing an OpenIDTokenEnhancer.
type EnhancerOptions func(opt *EnhancerOption)

// EnhancerOption holds the dependencies of an OpenIDTokenEnhancer.
type EnhancerOption struct {
	Issuer     security.Issuer
	JwtEncoder jwt.JwtEncoder
}

// OpenIDTokenEnhancer implements order.Ordered and TokenEnhancer
// OpenIDTokenEnhancer generate OpenID ID Token and set it to token details
//
//goland:noinspection GoNameStartsWithPackageName
type OpenIDTokenEnhancer struct {
	issuer     security.Issuer
	jwtEncoder jwt.JwtEncoder
}

// NewOpenIDTokenEnhancer builds an OpenIDTokenEnhancer from the given options.
func NewOpenIDTokenEnhancer(opts ...EnhancerOptions) *OpenIDTokenEnhancer {
	var opt EnhancerOption
	for _, apply := range opts {
		apply(&opt)
	}
	return &OpenIDTokenEnhancer{issuer: opt.Issuer, jwtEncoder: opt.JwtEncoder}
}

// Order places this enhancer at auth.TokenEnhancerOrderTokenDetails.
func (oe *OpenIDTokenEnhancer) Order() int {
	return auth.TokenEnhancerOrderTokenDetails
}
// Enhance generates an OpenID Connect ID token for oauth and stores it in the access token's details
// under oauth2.JsonFieldIDTokenValue.
//
// The token is returned unchanged when shouldSkip reports that no ID token is needed. Only
// *oauth2.DefaultAccessToken is supported, because the ID token is attached via PutDetails.
// Failures while populating claims or encoding the JWT are wrapped as oauth2 internal errors.
func (oe *OpenIDTokenEnhancer) Enhance(ctx context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) (oauth2.AccessToken, error) {
	if oe.shouldSkip(oauth) {
		return token, nil
	}
	t, ok := token.(*oauth2.DefaultAccessToken)
	if !ok {
		// Report the type of the incoming token: t is a typed nil here and would always
		// show up as *oauth2.DefaultAccessToken, hiding the real offending type.
		return nil, oauth2.NewInternalError("unsupported token implementation %T", token)
	}
	specs := oe.determineClaimSpecs(oauth.OAuth2Request())
	requested := oe.determineRequestedClaims(oauth.OAuth2Request())
	c := IdTokenClaims{}
	e := claims.Populate(ctx, &c,
		claims.WithSpecs(specs...),
		claims.WithSource(oauth),
		claims.WithIssuer(oe.issuer),
		claims.WithAccessToken(token),
		// explicitly requested claims may come from any spec, not only the scope-derived ones
		claims.WithRequestedClaims(requested, fullSpecs...),
	)
	if e != nil {
		return nil, oauth2.NewInternalError(e)
	}
	idToken, e := oe.jwtEncoder.Encode(ctx, &c)
	if e != nil {
		return nil, oauth2.NewInternalError(e)
	}
	t.PutDetails(oauth2.JsonFieldIDTokenValue, idToken)
	return t, nil
}
// shouldSkip reports whether no ID token should be issued. It returns true when:
//   - there is no OAuth2 request;
//   - the grant type is not in SupportedGrantTypes;
//   - the openid scope was not requested;
//   - the request asks for a "token" response type without "id_token";
//   - there is no user authentication.
func (oe *OpenIDTokenEnhancer) shouldSkip(oauth oauth2.Authentication) bool {
	req := oauth.OAuth2Request()
	if req == nil {
		return true
	}
	switch {
	case !SupportedGrantTypes.Has(req.GrantType()):
		// grant type not supported
		return true
	case !req.Scopes().Has(oauth2.ScopeOidc):
		// openid scope not requested
		return true
	case req.ResponseTypes().Has("token") && !req.ResponseTypes().Has("id_token"):
		// implicit flow without id_token response type
		return true
	}
	// not user authorized
	return oauth.UserAuthentication() == nil
}
// determineClaimSpecs selects which claim specs go into the ID token, based on the scopes defined by
// Core Spec 5.4: https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims
//
// Note 1: per spec, when response_type is "code" or "token" an access token is issued, so profile, email,
// phone and address claims are served by the user info endpoint and only the basic claims go into id_token.
// Note 2: outside the OIDC spec, some grant types (e.g. password, switch context) do not use response types.
// For legacy support, grant types in FullIdTokenGrantTypes always get the full id_token regardless of scopes.
// Otherwise, the basic specs are combined with the spec of every requested scope listed in scopedSpecs.
func (oe *OpenIDTokenEnhancer) determineClaimSpecs(request oauth2.OAuth2Request) []map[string]claims.ClaimSpec {
	if request == nil || request.Scopes() == nil || !request.Approved() {
		return defaultSpecs
	}
	responseTypes := request.ResponseTypes()
	if responseTypes.Has("code") || responseTypes.Has("token") {
		return defaultSpecs
	}
	if FullIdTokenGrantTypes.Has(request.GrantType()) {
		return fullSpecs
	}
	scopes := request.Scopes()
	// copy defaultSpecs into a fresh slice so appends never touch the shared package-level slice
	specs := make([]map[string]claims.ClaimSpec, 0, len(defaultSpecs)+len(scopes))
	specs = append(specs, defaultSpecs...)
	for scope, spec := range scopedSpecs {
		if scopes.Has(scope) {
			specs = append(specs, spec)
		}
	}
	return specs
}
// determineRequestedClaims decodes the JSON value of the claims request parameter (oauth2.ParameterClaims,
// read from the request extensions) and returns its ID-token section. It returns nil when the parameter is
// absent, is not a string, or is not valid JSON. request must be non-nil.
func (oe *OpenIDTokenEnhancer) determineRequestedClaims(request oauth2.OAuth2Request) claims.RequestedClaims {
	raw, isString := request.Extensions()[oauth2.ParameterClaims].(string)
	if !isString {
		return nil
	}
	var parsed ClaimsRequest
	if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
		return nil
	}
	return parsed.IdToken
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"encoding/gob"
"github.com/cisco-open/go-lanai/pkg/log"
)
// logger is the package-wide logger, named "OAuth2.Auth".
var logger = log.New("OAuth2.Auth")

// init registers the request types with encoding/gob so that values of these types can be gob-encoded
// when held in interface-typed fields.
func init() {
	for _, v := range []interface{}{
		(*AuthorizeRequest)(nil),
		(*TokenRequest)(nil),
	} {
		gob.Register(v)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"net/url"
)
/********************************************
Helper functions for OAuth2 Redirects
********************************************/
// composeRedirectUrl builds the redirect URL for an authorize request.
//
// It resolves the redirect URI (see findRedirectUri), then appends values plus the resolved state, if any,
// as query parameters. The caller's values map is never modified, and a nil map is accepted.
// useFragment is currently ignored: appendRedirectUrl always uses the query string.
// Returns an error when no redirect URI is known or it is not a valid absolute URI.
func composeRedirectUrl(c context.Context, r *AuthorizeRequest, values map[string]string, useFragment bool) (string, error) {
	redirectUrl, ok := findRedirectUri(c, r)
	if !ok {
		return "", fmt.Errorf("redirect URI is unknown")
	}
	// Work on a copy: writing "state" into the caller's map would leak into the caller's data
	// and panic when values is nil.
	params := make(map[string]string, len(values)+1)
	for k, v := range values {
		params[k] = v
	}
	if state, ok := findRedirectState(c, r); ok {
		params[oauth2.ParameterState] = state
	}
	return appendRedirectUrl(redirectUrl, params)
}
// appendRedirectUrl adds params to the query string of redirectUrl. Existing query parameters are kept
// and params are added alongside them. redirectUrl must parse as an absolute request URI; otherwise an
// oauth2 invalid-redirect-uri error is returned.
func appendRedirectUrl(redirectUrl string, params map[string]string) (string, error) {
	target, err := url.ParseRequestURI(redirectUrl)
	if err != nil || !target.IsAbs() {
		return "", oauth2.NewInvalidRedirectUriError("invalid redirect_uri")
	}
	// TODO support fragments
	q := target.Query()
	for name, value := range params {
		q.Add(name, value)
	}
	target.RawQuery = q.Encode()
	return target.String(), nil
}
// findRedirectUri returns the redirect URI for the current authorize flow. A string already resolved into the
// context under oauth2.CtxKeyResolvedAuthorizeRedirect wins; otherwise r.RedirectUri is used.
// The bool is false only when neither source is available.
func findRedirectUri(c context.Context, r *AuthorizeRequest) (string, bool) {
	if resolved, found := c.Value(oauth2.CtxKeyResolvedAuthorizeRedirect).(string); found {
		return resolved, true
	}
	if r == nil {
		return "", false
	}
	return r.RedirectUri, true
}
// findRedirectState returns the state for the current authorize flow. A string already resolved into the
// context under oauth2.CtxKeyResolvedAuthorizeState wins; otherwise r.State is used.
// The bool is false only when neither source is available.
func findRedirectState(c context.Context, r *AuthorizeRequest) (string, bool) {
	if resolved, found := c.Value(oauth2.CtxKeyResolvedAuthorizeState).(string); found {
		return resolved, true
	}
	if r == nil {
		return "", false
	}
	return r.State, true
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package revoke
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/session"
"net/http"
"strings"
)
const (
// bearerTokenPrefix is the "Authorization" header prefix (scheme plus trailing space) for bearer tokens.
bearerTokenPrefix = "Bearer "
)
// HanlderOptions configures a TokenRevokingLogoutHandler; see NewTokenRevokingLogoutHandler.
// NOTE(review): "Hanlder" is a misspelling of "Handler", kept because the name is exported.
type HanlderOptions func(opt *HanlderOption)
// HanlderOption holds the settings applied by HanlderOptions.
type HanlderOption struct {
// Revoker performs the session-based and token-based revocation.
Revoker auth.AccessRevoker
}
// TokenRevokingLogoutHandler is a logout handler that invalidates sessions and/or revokes tokens,
// depending on the HTTP method:
//
//   - GET: used for logout by session-controlled clients, which send the user to this endpoint. The session is
//     invalidated, and as a result the tokens controlled by this session are invalidated (the original notes
//     refer to the NfvClientDetails.useSessionTimeout property). Also used by SSO logout (OIDC, and SAML GET
//     binding), where the session and the tokens it controls are invalidated the same way.
//   - POST: used for logout by SSO logout (SAML POST binding). Same effect as GET.
//   - PUT/DELETE: used for token revocation, typically for service login. The token is taken from the
//     request (Authorization header first, then the access token parameter) and only that token is revoked.
//
// Original author: Livan Du, created on 2018-05-04.
type TokenRevokingLogoutHandler struct {
revoker auth.AccessRevoker
}
// NewTokenRevokingLogoutHandler builds a TokenRevokingLogoutHandler, applying the option functions in order.
func NewTokenRevokingLogoutHandler(opts ...HanlderOptions) *TokenRevokingLogoutHandler {
	var cfg HanlderOption
	for i := range opts {
		opts[i](&cfg)
	}
	handler := new(TokenRevokingLogoutHandler)
	handler.revoker = cfg.Revoker
	return handler
}
// HandleLogout dispatches on the request method:
//   - GET and POST revoke the tokens of the current session and clear the security context and session;
//   - PUT and DELETE revoke only the access token carried by the request;
//   - any other method is a no-op.
func (h TokenRevokingLogoutHandler) HandleLogout(ctx context.Context, r *http.Request, rw http.ResponseWriter, auth security.Authentication) error {
	switch r.Method {
	case http.MethodGet, http.MethodPost:
		return h.handleGetOrPost(ctx, auth)
	case http.MethodPut, http.MethodDelete:
		return h.handleDefault(ctx, r)
	default:
		return nil
	}
}
// handleGetOrPost revokes every token bound to the current session through the revoker.
// Whatever the outcome, the security context and the session are cleared from ctx on return.
// A missing session is logged at debug level and is not treated as an error; a revocation failure is
// logged and returned.
func (h TokenRevokingLogoutHandler) handleGetOrPost(ctx context.Context, _ security.Authentication) error {
	defer func() {
		security.MustClear(ctx)
		session.MustSet(ctx, nil)
	}()
	current := session.Get(ctx)
	if current == nil {
		logger.WithContext(ctx).Debugf("invalid use of GET/POST /logout endpoint. session is not found")
		return nil
	}
	err := h.revoker.RevokeWithSessionId(ctx, current.GetID(), current.Name())
	if err != nil {
		logger.WithContext(ctx).Warnf("unable to revoke tokens with session %s: %v", current.GetID(), err)
	}
	return err
}
// handleDefault serves PUT and DELETE. Unlike GET/POST, it leaves the current authentication untouched and
// only revokes the access token carried by the request. A missing token is logged and ignored (nil is
// returned); a revocation failure is logged, with the token value capped, and returned.
func (h TokenRevokingLogoutHandler) handleDefault(ctx context.Context, r *http.Request) error {
	tokenValue, err := h.extractAccessToken(ctx, r)
	if err != nil {
		logger.WithContext(ctx).Warnf("unable to revoke token: %v", err)
		return nil
	}
	if err = h.revoker.RevokeWithTokenValue(ctx, tokenValue, auth.RevokerHintAccessToken); err != nil {
		logger.WithContext(ctx).Warnf("unable to revoke token with value %s: %v", log.Capped(tokenValue, 20), err)
		return err
	}
	return nil
}
// extractAccessToken returns the access token carried by r.
//
// A bearer token in the "Authorization" header takes precedence. The scheme is matched case-insensitively
// and the token is trimmed. Otherwise, the oauth2.ParameterAccessToken form/query parameter is used.
// A blank bearer token counts as absent, so the parameter fallback applies. An error is returned when
// neither source yields a non-blank token.
func (h TokenRevokingLogoutHandler) extractAccessToken(_ context.Context, r *http.Request) (string, error) {
	// try header first; compare only the prefix instead of upper-casing the whole header
	header := r.Header.Get("Authorization")
	if len(header) >= len(bearerTokenPrefix) && strings.EqualFold(header[:len(bearerTokenPrefix)], bearerTokenPrefix) {
		if token := strings.TrimSpace(header[len(bearerTokenPrefix):]); token != "" {
			return token, nil
		}
	}
	// then try param
	value := strings.TrimSpace(r.FormValue(oauth2.ParameterAccessToken))
	if value == "" {
		return "", fmt.Errorf(`access token is required either from "Authorization" header or parameter "%s"`, oauth2.ParameterAccessToken)
	}
	return value, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package revoke
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/security/session/common"
)
// RevokerOptions configures a DefaultAccessRevoker; see NewDefaultAccessRevoker.
type RevokerOptions func(opt *RevokerOption)
// RevokerOption holds the settings applied by RevokerOptions.
type RevokerOption struct {
// AuthRegistry revokes access per session, user, client or individual token.
AuthRegistry auth.AuthorizationRegistry
// SessionName is the session name used when invalidating sessions by principal. Defaults to common.DefaultName.
SessionName string
// SessionStore loads, expires and invalidates sessions.
SessionStore session.Store
// TokenStoreReader resolves token values into token objects before revocation.
TokenStoreReader oauth2.TokenStoreReader
}
// DefaultAccessRevoker implements auth.AccessRevoker by combining session expiry/invalidation
// (via the session store) with token revocation (via the authorization registry).
type DefaultAccessRevoker struct {
authRegistry auth.AuthorizationRegistry
sessionName string
sessionStore session.Store
tokenStoreReader oauth2.TokenStoreReader
}
// NewDefaultAccessRevoker builds a DefaultAccessRevoker. SessionName starts as common.DefaultName; the option
// functions are then applied in order.
func NewDefaultAccessRevoker(opts ...RevokerOptions) *DefaultAccessRevoker {
	cfg := RevokerOption{SessionName: common.DefaultName}
	for _, apply := range opts {
		apply(&cfg)
	}
	return &DefaultAccessRevoker{
		tokenStoreReader: cfg.TokenStoreReader,
		sessionStore:     cfg.SessionStore,
		sessionName:      cfg.SessionName,
		authRegistry:     cfg.AuthRegistry,
	}
}
// RevokeWithSessionId expires the session identified by sessionId and sessionName, if it can be loaded.
// It then revokes all access tied to that session through the authorization registry, passing true as the
// third argument (presumably "also revoke refresh tokens" — confirm against AuthorizationRegistry).
// A registry error takes precedence; otherwise a session-expiry error, already logged, is returned.
func (r DefaultAccessRevoker) RevokeWithSessionId(ctx context.Context, sessionId string, sessionName string) (err error) {
	// expire session; a load failure or missing session is silently skipped
	s, loadErr := r.sessionStore.Get(sessionId, sessionName)
	if loadErr == nil && s != nil {
		if expireErr := s.ExpireNow(ctx); expireErr != nil {
			logger.WithContext(ctx).Warnf("Unable to expire session for session ID [%s]: %v", s.GetID(), expireErr)
			err = expireErr
		}
	}
	// revoke all tokens
	if revokeErr := r.authRegistry.RevokeSessionAccess(ctx, sessionId, true); revokeErr != nil {
		return revokeErr
	}
	return err
}
// RevokeWithUsername invalidates every session of username (under the configured session name) and
// revokes the user's access through the authorization registry. revokeRefreshToken is forwarded to the
// registry. A registry error takes precedence; otherwise a session invalidation error, already logged, is returned.
func (r DefaultAccessRevoker) RevokeWithUsername(ctx context.Context, username string, revokeRefreshToken bool) (err error) {
	// expire all sessions
	store := r.sessionStore.WithContext(ctx)
	if invalidateErr := store.InvalidateByPrincipalName(username, r.sessionName); invalidateErr != nil {
		logger.WithContext(ctx).Warnf("Unable to expire session for username [%s]: %v", username, invalidateErr)
		err = invalidateErr
	}
	// revoke all tokens
	if revokeErr := r.authRegistry.RevokeUserAccess(ctx, username, revokeRefreshToken); revokeErr != nil {
		return revokeErr
	}
	return err
}
// RevokeWithClientId revokes all access granted to clientId through the authorization registry.
// revokeRefreshToken is forwarded to the registry, just like RevokeWithUsername does, so callers control
// whether refresh tokens are revoked too.
func (r DefaultAccessRevoker) RevokeWithClientId(ctx context.Context, clientId string, revokeRefreshToken bool) error {
	return r.authRegistry.RevokeClientAccess(ctx, clientId, revokeRefreshToken)
}
// RevokeWithTokenValue reads the token identified by tokenValue and revokes it. hint selects whether
// tokenValue is an access token or a refresh token. Store read errors and registry errors are returned as-is.
// Any other hint yields an "unsupported revoker token hint" error.
func (r DefaultAccessRevoker) RevokeWithTokenValue(ctx context.Context, tokenValue string, hint auth.RevokerTokenHint) error {
	if hint == auth.RevokerHintAccessToken {
		accessToken, err := r.tokenStoreReader.ReadAccessToken(ctx, tokenValue)
		if err != nil {
			return err
		}
		return r.authRegistry.RevokeAccessToken(ctx, accessToken)
	}
	if hint == auth.RevokerHintRefreshToken {
		refreshToken, err := r.tokenStoreReader.ReadRefreshToken(ctx, tokenValue)
		if err != nil {
			return err
		}
		return r.authRegistry.RevokeRefreshToken(ctx, refreshToken)
	}
	return fmt.Errorf("unsupported revoker token hint")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package revoke
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/logout"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
"github.com/cisco-open/go-lanai/pkg/utils"
"net/http"
"net/url"
)
// SuccessOptions configures a TokenRevokeSuccessHandler; see NewTokenRevokeSuccessHandler.
type SuccessOptions func(opt *SuccessOption)
// SuccessOption holds the settings applied by SuccessOptions.
type SuccessOption struct {
// ClientStore is used to load and validate the client ID sent along with a redirect URI.
ClientStore oauth2.OAuth2ClientStore
// RedirectWhitelist holds wildcard URL patterns accepted as redirect targets even when the redirect URI
// cannot be resolved for the client.
RedirectWhitelist utils.StringSet
// WhitelabelErrorPath is the redirect target used when the client or redirect URI is invalid.
WhitelabelErrorPath string
// WhitelabelLoggedOutPath is the default redirect after logout when no redirect URI is given.
WhitelabelLoggedOutPath string
}
// TokenRevokeSuccessHandler implements security.AuthenticationSuccessHandler for the logout/revoke endpoint.
// GET and POST requests are redirected; all other methods get a plain 200 response.
type TokenRevokeSuccessHandler struct {
clientStore oauth2.OAuth2ClientStore
// whitelist holds wildcard URL patterns accepted as redirect targets
whitelist utils.StringSet
// fallback handles invalid client / redirect URI errors by redirecting to the whitelabel error path
fallback security.AuthenticationErrorHandler
// defaultSuccessHandler redirects to the whitelabel logged-out path when no redirect URI is given
defaultSuccessHandler security.AuthenticationSuccessHandler
}
// NewTokenRevokeSuccessHandler builds a TokenRevokeSuccessHandler from the option functions, applied in order.
// Errors fall back to a redirect to WhitelabelErrorPath. When no redirect URI is requested, the user is sent
// to WhitelabelLoggedOutPath, built as a relative-path redirect.
func NewTokenRevokeSuccessHandler(opts ...SuccessOptions) *TokenRevokeSuccessHandler {
	var cfg SuccessOption
	for _, apply := range opts {
		apply(&cfg)
	}
	handler := TokenRevokeSuccessHandler{
		clientStore: cfg.ClientStore,
		whitelist:   cfg.RedirectWhitelist,
	}
	handler.fallback = redirect.NewRedirectWithURL(cfg.WhitelabelErrorPath)
	handler.defaultSuccessHandler = redirect.NewRedirectWithRelativePath(cfg.WhitelabelLoggedOutPath, true)
	return &handler
}
// HandleAuthenticationSuccess redirects GET and POST requests (see redirect). Every other method,
// including PUT and DELETE, gets a plain 200 OK with an empty body.
func (h TokenRevokeSuccessHandler) HandleAuthenticationSuccess(ctx context.Context, r *http.Request, rw http.ResponseWriter, from, to security.Authentication) {
	if r.Method == http.MethodGet || r.Method == http.MethodPost {
		h.redirect(ctx, r, rw, from, to)
		return
	}
	h.status(ctx, rw)
}
// redirect sends the user to the requested redirect URI after logout.
//
// Without a redirect URI parameter, the default (logged-out page) handler is used. Otherwise the client ID
// parameter is validated and the URI is resolved against the client's registration. When that fails, the
// URI is still accepted if it matches the whitelist. An invalid client or an unacceptable URI is handed to
// the fallback error handler.
func (h TokenRevokeSuccessHandler) redirect(ctx context.Context, r *http.Request, rw http.ResponseWriter, from, to security.Authentication) {
	// Note: we don't have error handling alternatives (except for panic)
	redirectUri := r.FormValue(oauth2.ParameterRedirectUri)
	if len(redirectUri) == 0 {
		h.defaultSuccessHandler.HandleAuthenticationSuccess(ctx, r, rw, from, to)
		return
	}
	client, err := auth.LoadAndValidateClientId(ctx, r.FormValue(oauth2.ParameterClientId), h.clientStore)
	if err != nil {
		h.fallback.HandleAuthenticationError(ctx, r, rw, err)
		return
	}
	target, err := auth.ResolveRedirectUri(ctx, redirectUri, client)
	if err != nil {
		// not resolvable for this client: accept it only when whitelisted
		if !h.isWhitelisted(ctx, redirectUri) {
			h.fallback.HandleAuthenticationError(ctx, r, rw, err)
			return
		}
		target = redirectUri
	}
	h.doRedirect(ctx, r, rw, target)
}
// doRedirect answers with a 302 Found to redirectUri, after appending any logout warnings as query parameters.
func (h TokenRevokeSuccessHandler) doRedirect(ctx context.Context, r *http.Request, rw http.ResponseWriter, redirectUri string) {
	target := h.appendWarnings(ctx, redirectUri)
	http.Redirect(rw, r, target, http.StatusFound)
	// the response is already committed by http.Redirect; a write failure here has no useful recovery
	_, _ = rw.Write([]byte{})
}
// status is the success response for non-redirecting methods (PUT, DELETE and anything other than GET/POST):
// a 200 OK with an empty body.
func (h TokenRevokeSuccessHandler) status(_ context.Context, rw http.ResponseWriter) {
	rw.WriteHeader(http.StatusOK)
	// empty body; there is nothing meaningful to do if this write fails
	_, _ = rw.Write([]byte{})
}
// isWhitelisted reports whether redirect matches any wildcard URL pattern in the whitelist.
// Patterns that fail to compile, and match attempts that error, are skipped.
func (h TokenRevokeSuccessHandler) isWhitelisted(_ context.Context, redirect string) bool {
	for pattern := range h.whitelist {
		matcher, err := auth.NewWildcardUrlMatcher(pattern)
		if err != nil {
			continue
		}
		if matched, err := matcher.Matches(redirect); err == nil && matched {
			return true
		}
	}
	return false
}
// appendWarnings adds one "warning" query parameter per logout warning found in ctx to redirect.
// redirect is returned unchanged when there are no warnings or when it cannot be parsed as a URL.
func (h TokenRevokeSuccessHandler) appendWarnings(ctx context.Context, redirect string) string {
	warnings := logout.GetWarnings(ctx)
	if len(warnings) == 0 {
		return redirect
	}
	target, err := url.Parse(redirect)
	if err != nil {
		return redirect
	}
	query := target.Query()
	for _, warning := range warnings {
		query.Add("warning", warning.Error())
	}
	target.RawQuery = query.Encode()
	return target.String()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/common"
"github.com/cisco-open/go-lanai/pkg/tenancy"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/google/uuid"
"time"
)
var (
// endOfWorld is a far-future sentinel used by determineExpiryTime to mean "no upper bound on expiry".
endOfWorld = time.Date(2999, time.December, 31, 23, 59, 59, 0, time.UTC)
)
// AuthorizationService creates OAuth2 authentications and issues access tokens for them.
type AuthorizationService interface {
// CreateAuthentication builds an oauth2.Authentication for the request and the (possibly nil) user authentication.
CreateAuthentication(ctx context.Context, request oauth2.OAuth2Request, userAuth security.Authentication) (oauth2.Authentication, error)
// SwitchAuthentication is like CreateAuthentication, but derives the result from an existing source authentication.
SwitchAuthentication(ctx context.Context, request oauth2.OAuth2Request, userAuth security.Authentication, src oauth2.Authentication) (oauth2.Authentication, error)
// CreateAccessToken issues an access token for the given authentication.
CreateAccessToken(ctx context.Context, oauth oauth2.Authentication) (oauth2.AccessToken, error)
// RefreshAccessToken issues a new access token for the given authentication, bound to refreshToken.
RefreshAccessToken(ctx context.Context, oauth oauth2.Authentication, refreshToken oauth2.RefreshToken) (oauth2.AccessToken, error)
}
/****************************
Default implementation
****************************/
// DASOptions configures a DefaultAuthorizationService; see NewDefaultAuthorizationService.
type DASOptions func(*DASOption)
// DASOption holds the settings applied by DASOptions.
type DASOption struct {
// DetailsFactory builds the security.ContextDetails of each authentication.
DetailsFactory *common.ContextDetailsFactory
ClientStore oauth2.OAuth2ClientStore
// AccountStore loads accounts by username; if it implements security.AccountFinalizer, accounts are finalized per tenant.
AccountStore security.AccountStore
TenantStore security.TenantStore
ProviderStore security.ProviderStore
// Issuer is given to the default basic-claims and refresh-token enhancers.
Issuer security.Issuer
// TokenStore reuses, saves and removes access tokens; also given to the default refresh-token enhancer.
TokenStore TokenStore
// TokenEnhancers run before the access token is saved. Replacing this list drops the default enhancers.
TokenEnhancers []TokenEnhancer
// PostTokenEnhancers run after the access token is saved. Empty by default.
PostTokenEnhancers []TokenEnhancer
}
// DefaultAuthorizationService implements AuthorizationService.
// Construct it with NewDefaultAuthorizationService; tokenEnhancer and postTokenEnhancer are composites of
// the configured enhancer lists.
type DefaultAuthorizationService struct {
detailsFactory *common.ContextDetailsFactory
clientStore oauth2.OAuth2ClientStore
accountStore security.AccountStore
tenantStore security.TenantStore
providerStore security.ProviderStore
tokenStore TokenStore
tokenEnhancer TokenEnhancer
postTokenEnhancer TokenEnhancer
}
// NewDefaultAuthorizationService creates a DefaultAuthorizationService from the given options.
//
// The default TokenEnhancers list contains the expiry, basic-claims, resource-ID, details and refresh-token
// enhancers. PostTokenEnhancers is empty by default. Both lists are wrapped with NewCompositeTokenEnhancer.
func NewDefaultAuthorizationService(opts ...DASOptions) *DefaultAuthorizationService {
basicEnhancer := BasicClaimsTokenEnhancer{}
refreshTokenEnhancer := RefreshTokenEnhancer{}
conf := DASOption{
TokenEnhancers: []TokenEnhancer{
&ExpiryTokenEnhancer{},
&basicEnhancer,
&ResourceIdTokenEnhancer{},
&DetailsTokenEnhancer{},
&refreshTokenEnhancer,
},
PostTokenEnhancers: []TokenEnhancer{},
}
for _, opt := range opts {
opt(&conf)
}
// The default list holds pointers to these two enhancers, so configuring them after the options run
// picks up the final Issuer/TokenStore. If an option replaced TokenEnhancers, they are simply unused.
basicEnhancer.issuer = conf.Issuer
refreshTokenEnhancer.issuer = conf.Issuer
refreshTokenEnhancer.tokenStore = conf.TokenStore
return &DefaultAuthorizationService{
detailsFactory: conf.DetailsFactory,
clientStore: conf.ClientStore,
accountStore: conf.AccountStore,
tenantStore: conf.TenantStore,
providerStore: conf.ProviderStore,
tokenStore: conf.TokenStore,
tokenEnhancer: NewCompositeTokenEnhancer(conf.TokenEnhancers...),
postTokenEnhancer: NewCompositeTokenEnhancer(conf.PostTokenEnhancers...),
}
}
// CreateAuthentication builds an oauth2.Authentication for request and the (possibly nil) user authentication.
// The client, account, tenant and provider facts are loaded, verified and published into the mutable context
// (see createContextDetails). The user authentication is then rebuilt from the freshly loaded account.
func (s *DefaultAuthorizationService) CreateAuthentication(ctx context.Context,
	request oauth2.OAuth2Request, user security.Authentication) (oauth oauth2.Authentication, err error) {
	return s.createOrSwitchAuthentication(ctx, request, user, nil)
}

// SwitchAuthentication behaves like CreateAuthentication, but derives the new authentication from src.
// src is recorded as the source authentication; its expiry caps the new expiry, and its authentication
// time is carried over.
func (s *DefaultAuthorizationService) SwitchAuthentication(ctx context.Context,
	request oauth2.OAuth2Request, user security.Authentication,
	src oauth2.Authentication) (oauth oauth2.Authentication, err error) {
	return s.createOrSwitchAuthentication(ctx, request, user, src)
}

// createOrSwitchAuthentication is the shared implementation of CreateAuthentication (src == nil) and
// SwitchAuthentication.
func (s *DefaultAuthorizationService) createOrSwitchAuthentication(ctx context.Context,
	request oauth2.OAuth2Request, user security.Authentication, src oauth2.Authentication) (oauth2.Authentication, error) {
	userAuth := ConvertToOAuthUserAuthentication(user)
	details, err := s.createContextDetails(ctx, request, userAuth, src)
	if err != nil {
		return nil, err
	}
	// reconstruct user auth based on newly loaded facts (account may have changed)
	if userAuth, err = s.createUserAuthentication(ctx, request, userAuth); err != nil {
		return nil, err
	}
	return oauth2.NewAuthentication(func(conf *oauth2.AuthOption) {
		conf.Request = request
		conf.UserAuth = userAuth
		conf.Details = details
	}), nil
}
// CreateAccessToken issues an access token for oauth. It uses the store's reusable token when available,
// otherwise a new token with a random UUID value. The token is then enhanced, saved and post-enhanced.
func (s *DefaultAuthorizationService) CreateAccessToken(c context.Context, oauth oauth2.Authentication) (oauth2.AccessToken, error) {
	return s.enhanceAndSaveAccessToken(c, oauth, s.reuseOrNewAccessToken(c, oauth))
}

// RefreshAccessToken issues an access token bound to refreshToken for oauth, after removing the access token
// currently associated with that refresh token.
func (s *DefaultAuthorizationService) RefreshAccessToken(c context.Context, oauth oauth2.Authentication, refreshToken oauth2.RefreshToken) (oauth2.AccessToken, error) {
	// Remove the existing access token first, so a refresh token can't be used to create an unlimited number
	// of access tokens. This is best effort: a removal failure is deliberately ignored.
	_ = s.tokenStore.RemoveAccessToken(c, refreshToken)
	token := s.reuseOrNewAccessToken(c, oauth)
	token.SetRefreshToken(refreshToken)
	return s.enhanceAndSaveAccessToken(c, oauth, token)
}

// enhanceAndSaveAccessToken runs the token enhancers on token, saves the result in the token store, and then
// applies the post-save enhancers. The first error encountered is returned.
func (s *DefaultAuthorizationService) enhanceAndSaveAccessToken(c context.Context, oauth oauth2.Authentication, token *oauth2.DefaultAccessToken) (oauth2.AccessToken, error) {
	enhanced, err := s.tokenEnhancer.Enhance(c, token, oauth)
	if err != nil {
		return nil, err
	}
	saved, err := s.tokenStore.SaveAccessToken(c, enhanced, oauth)
	if err != nil {
		return nil, err
	}
	return s.postTokenEnhancer.Enhance(c, saved, oauth)
}
/*
***************************
Authorization Helpers
***************************
*/
// authFacts holds the entities resolved and verified by loadAndVerifyFacts for a single authorization.
type authFacts struct {
request oauth2.OAuth2Request
client oauth2.OAuth2Client
// account is nil for client-only authentication
account security.Account
// tenant and provider are nil when no tenant is selected
tenant *security.Tenant
provider *security.Provider
// source is the original authentication when switching context (set by createContextDetails); nil otherwise
source oauth2.Authentication
// userAuth is the user authentication re-created after account finalization; nil for client-only authentication
userAuth oauth2.UserAuthentication
}
// createContextDetails loads and verifies the authorization facts (see loadAndVerifyFacts) and publishes them
// into the mutable context found in ctx. It then builds the security.ContextDetails with the details factory.
//
// Values set on the mutable context:
//   - authenticated client and account;
//   - authorized tenant and provider;
//   - user authentication;
//   - issue time (now, in UTC);
//   - the source authentication, when src is non-nil;
//   - the expiry and authentication times, when non-zero.
//
// Returns an immutable-context error when ctx carries no mutable context.
// NOTE: statement order matters. determineExpiryTime reads the issue time back via ctx.Value (assumes values Set
// on the mutable context are visible through ctx — confirm), and both time helpers rely on facts.source being set first.
func (s *DefaultAuthorizationService) createContextDetails(ctx context.Context,
request oauth2.OAuth2Request, userAuth oauth2.UserAuthentication,
src oauth2.Authentication) (security.ContextDetails, error) {
now := time.Now().UTC()
facts, e := s.loadAndVerifyFacts(ctx, request, userAuth)
if e != nil {
return nil, e
}
mutableSetter := utils.FindMutableContext(ctx)
if mutableSetter == nil {
return nil, newImmutableContextError()
}
mutableSetter.Set(oauth2.CtxKeyAuthenticatedClient, facts.client)
mutableSetter.Set(oauth2.CtxKeyAuthenticatedAccount, facts.account)
mutableSetter.Set(oauth2.CtxKeyAuthorizedTenant, facts.tenant)
mutableSetter.Set(oauth2.CtxKeyAuthorizedProvider, facts.provider)
mutableSetter.Set(oauth2.CtxKeyUserAuthentication, facts.userAuth)
mutableSetter.Set(oauth2.CtxKeyAuthorizationIssueTime, now)
if src != nil {
facts.source = src
mutableSetter.Set(oauth2.CtxKeySourceAuthentication, src)
}
// expiry (zero means no expiry, in which case nothing is set)
expiry := s.determineExpiryTime(ctx, request, facts)
if !expiry.IsZero() {
mutableSetter.Set(oauth2.CtxKeyAuthorizationExpiryTime, expiry)
}
// auth time
authTime := s.determineAuthenticationTime(ctx, userAuth, facts)
if !authTime.IsZero() {
mutableSetter.Set(oauth2.CtxKeyAuthenticationTime, authTime)
}
// create context details
return s.detailsFactory.New(ctx, request) //nolint:contextcheck // this is expected usage of MutableCtx
}
// createUserAuthentication rebuilds userAuth from the account stored in ctx under oauth2.CtxKeyAuthenticatedAccount.
//
// The principal becomes the account's username, and each account permission is mapped to true. The state and
// map-typed details are carried over from userAuth; non-map or nil details become an empty map.
// It returns nil for a nil userAuth, and userAuth unchanged when ctx holds no account.
func (s *DefaultAuthorizationService) createUserAuthentication(ctx context.Context, _ oauth2.OAuth2Request, userAuth oauth2.UserAuthentication) (oauth2.UserAuthentication, error) {
	if userAuth == nil {
		return nil, nil
	}
	account, found := ctx.Value(oauth2.CtxKeyAuthenticatedAccount).(security.Account)
	if !found {
		return userAuth, nil
	}
	granted := account.Permissions()
	permissions := make(map[string]interface{}, len(granted))
	for _, p := range granted {
		permissions[p] = true
	}
	details, _ := userAuth.Details().(map[string]interface{})
	if details == nil {
		details = map[string]interface{}{}
	}
	return oauth2.NewUserAuthentication(func(opt *oauth2.UserAuthOption) {
		opt.Principal = account.Username()
		opt.Permissions = permissions
		opt.State = userAuth.State()
		opt.Details = details
	}), nil
}
// loadAndVerifyFacts resolves and validates everything needed to issue an authorization:
//   - the authenticated client (required; taken from ctx);
//   - the user's account, only when userAuth is non-nil (it must be neither locked nor disabled);
//   - the tenant (explicitly requested or the resolved default) and access to it;
//   - the tenant's provider.
//
// When the account store implements security.AccountFinalizer, the account is finalized for the selected tenant.
// The finalized account must keep its ID, username and tenancy (see verifyFinalizedAccount).
// For client-only authentication, the returned facts carry neither an account nor a user authentication.
func (s *DefaultAuthorizationService) loadAndVerifyFacts(ctx context.Context, request oauth2.OAuth2Request, userAuth security.Authentication) (*authFacts, error) {
	client := RetrieveAuthenticatedClient(ctx)
	if client == nil {
		return nil, newInvalidClientError()
	}
	account, err := s.loadAccount(ctx, request, userAuth)
	if err != nil {
		return nil, err
	}
	if account != nil && (account.Locked() || account.Disabled()) {
		return nil, newInvalidUserError("unsupported user's account locked or disabled")
	}
	defaultTenantId, assignedTenants, err := common.ResolveClientUserTenants(ctx, account, client)
	if err != nil {
		return nil, newInvalidTenantForUserError(fmt.Errorf("can't resolve account [%T] and client's [%T] tenants", account, client))
	}
	tenant, err := s.loadTenant(ctx, request, defaultTenantId)
	if err != nil {
		return nil, err
	}
	if err = s.verifyTenantAccess(ctx, tenant, assignedTenants); err != nil {
		return nil, err
	}
	provider, err := s.loadProvider(ctx, request, tenant)
	if err != nil {
		return nil, err
	}
	if account == nil { // at this point we have all the information we need if it's only client auth
		return &authFacts{
			request:  request,
			client:   client,
			tenant:   tenant,
			provider: provider,
		}, nil
	}
	if finalizer, ok := s.accountStore.(security.AccountFinalizer); ok {
		finalized, finalizeErr := finalizer.Finalize(ctx, account, security.FinalizeWithTenant(tenant))
		if finalizeErr != nil {
			return nil, finalizeErr
		}
		if verifyErr := s.verifyFinalizedAccount(account, finalized); verifyErr != nil {
			return nil, verifyErr
		}
		account = finalized
	}
	// After account finalization, the user authentication is re-created. Principal and state are kept from
	// userAuth; details and permissions may change.
	// NOTE(review): the option below copies permissions from the authentication passed to the callback
	// (presumably userAuth), not from the finalized account — confirm this is intended.
	newUserAuth := ConvertToOAuthUserAuthentication(
		userAuth,
		ConvertWithSkipTypeCheck(true),
		func(option *ConvertOptions) {
			option.AppendUserAuthOptions(func(userAuth security.Authentication) oauth2.UserAuthOptions {
				return func(opt *oauth2.UserAuthOption) {
					opt.Permissions = userAuth.Permissions()
				}
			})
		})
	return &authFacts{
		request:  request,
		client:   client,
		account:  account,
		tenant:   tenant,
		provider: provider,
		userAuth: newUserAuth,
	}, nil
}

// verifyFinalizedAccount checks that an AccountFinalizer did not tamper with identity or tenancy. The checks are:
//   - finalized must be non-nil, with the same ID and username as original;
//   - both accounts must implement security.AccountTenancy;
//   - the default designated tenant and the set of designated tenants must be unchanged.
func (s *DefaultAuthorizationService) verifyFinalizedAccount(original, finalized security.Account) error {
	if finalized == nil || finalized.ID() != original.ID() || finalized.Username() != original.Username() {
		return newTamperedIDOrUsernameError()
	}
	// Both assertions are checked: asserting the original account unchecked would panic
	// whenever it does not implement security.AccountTenancy.
	before, ok := original.(security.AccountTenancy)
	if !ok {
		return newTamperedTenancyError()
	}
	after, ok := finalized.(security.AccountTenancy)
	if !ok {
		return newTamperedTenancyError()
	}
	if after.DefaultDesignatedTenantId() != before.DefaultDesignatedTenantId() {
		return newTamperedTenancyError()
	}
	if !utils.NewStringSet(after.DesignatedTenantIds()...).Equals(utils.NewStringSet(before.DesignatedTenantIds()...)) {
		return newTamperedTenancyError()
	}
	return nil
}
// loadAccount loads the account of the authenticated user from the account store.
// It returns (nil, nil) for client-only authentication, where userAuth is nil. It returns an unauthenticated
// error for a non-authenticated or principal-less userAuth, and an invalid-user error when the username cannot
// be determined or the account cannot be loaded.
func (s *DefaultAuthorizationService) loadAccount(
	ctx context.Context,
	_ oauth2.OAuth2Request,
	userAuth security.Authentication,
) (security.Account, error) {
	if userAuth == nil {
		return nil, nil
	}
	// sanity check, this should not happen
	authenticated := userAuth.State() >= security.StateAuthenticated
	if !authenticated || userAuth.Principal() == nil {
		return nil, newUnauthenticatedUserError()
	}
	username, err := security.GetUsername(userAuth)
	if err != nil {
		return nil, newInvalidUserError(err)
	}
	account, err := s.accountStore.LoadAccountByUsername(ctx, username)
	if err != nil {
		return nil, newInvalidUserError(err)
	}
	return account, nil
}
// loadTenant loads the tenant selected by the request's tenant ID and/or tenant external ID parameters.
//
// When neither parameter is given, defaultTenantId is used, unless it is the wildcard, in which case no tenant
// is loaded. If both parameters are given, both are loaded and the external-ID result wins.
// It returns (nil, nil) when no tenant is selected, and an invalid-tenant error when a lookup fails.
func (s *DefaultAuthorizationService) loadTenant(
	ctx context.Context,
	request oauth2.OAuth2Request,
	defaultTenantId string,
) (*security.Tenant, error) {
	params := request.Parameters()
	// a missing parameter yields "", so a missing key and an empty value are treated alike
	tenantId := params[oauth2.ParameterTenantId]
	tenantExternalId := params[oauth2.ParameterTenantExternalId]
	// TODO review this logic regarding to wildcard "*" and tenant auto-selection
	// Note: default tenant ID might be wildcard "*", in such case, we don't load tenant if tenant is not selected
	if tenantId == "" && tenantExternalId == "" && defaultTenantId != security.SpecialTenantIdWildcard {
		tenantId = defaultTenantId
	}
	var tenant *security.Tenant
	if tenantId != "" {
		loaded, err := s.tenantStore.LoadTenantById(ctx, tenantId)
		if err != nil {
			return nil, newInvalidTenantForUserError(fmt.Sprintf("error loading tenant with id [%s]", tenantId))
		}
		tenant = loaded
	}
	if tenantExternalId != "" {
		loaded, err := s.tenantStore.LoadTenantByExternalId(ctx, tenantExternalId)
		if err != nil {
			return nil, newInvalidTenantForUserError(fmt.Sprintf("error loading tenant with externalId [%s]", tenantExternalId))
		}
		tenant = loaded
	}
	return tenant, nil
}
// verifyTenantAccess checks that tenant is reachable from the assigned tenants. Access is granted when no tenant
// is selected, when the assignment contains the wildcard, or when any assigned tenant has tenant as a
// descendant (per tenancy.AnyHasDescendant). Otherwise an invalid-grant error is returned.
func (s *DefaultAuthorizationService) verifyTenantAccess(ctx context.Context, tenant *security.Tenant, assignedTenantIds []string) error {
	if tenant == nil {
		return nil
	}
	assigned := utils.NewStringSet(assignedTenantIds...)
	if assigned.Has(security.SpecialTenantIdWildcard) || tenancy.AnyHasDescendant(ctx, assigned, tenant.Id) {
		return nil
	}
	return oauth2.NewInvalidGrantError("user does not have access to specified tenant")
}
// loadProvider loads the provider of tenant from the provider store.
// It returns (nil, nil) when no tenant is selected, and an invalid-provider error when the tenant has no
// provider ID or the provider cannot be loaded.
func (s *DefaultAuthorizationService) loadProvider(ctx context.Context, _ oauth2.OAuth2Request, tenant *security.Tenant) (*security.Provider, error) {
	if tenant == nil {
		return nil, nil
	}
	providerId := tenant.ProviderId
	if providerId == "" {
		return nil, newInvalidProviderError("provider ID is not available")
	}
	provider, e := s.providerStore.LoadProviderById(ctx, providerId)
	if e != nil {
		return nil, newInvalidProviderError(fmt.Sprintf("tenant [%s]'s provider is invalid", tenant.DisplayName))
	}
	return provider, nil
}
// determineExpiryTime computes the expiry time of the access token being issued.
//
// When a source authentication (context switching) with security.AuthenticationDetails is present,
// the result is never later than the source's expiry time. When the client has no access token
// validity configured, the result is the zero time if there is no source bound; otherwise it is the
// source expiry. If validity is configured, the result is the issue time plus the validity in UTC,
// capped by the source expiry through minTime.
func (s *DefaultAuthorizationService) determineExpiryTime(ctx context.Context, _ oauth2.OAuth2Request, facts *authFacts) (expiry time.Time) {
	// endOfWorld is used as the "no upper bound" sentinel
	upperBound := endOfWorld
	// When switching context, expiry should be no later than the original expiry time
	if facts.source != nil {
		if srcAuth, ok := facts.source.Details().(security.AuthenticationDetails); ok {
			upperBound = srcAuth.ExpiryTime()
		}
	}

	if facts.client.AccessTokenValidity() == 0 {
		// time.Time must be compared with Equal; == also compares location and monotonic reading
		if upperBound.Equal(endOfWorld) {
			return time.Time{}
		}
		return upperBound
	}

	issueTime, ok := ctx.Value(oauth2.CtxKeyAuthorizationIssueTime).(time.Time)
	if !ok {
		// NOTE(review): the issue time is expected to be set on ctx upstream; fall back to "now"
		// instead of panicking on the type assertion
		issueTime = time.Now()
	}
	expiry = issueTime.Add(facts.client.AccessTokenValidity()).UTC()
	return minTime(expiry, upperBound)
}
// determineAuthenticationTime returns the authentication time for the authorization being built.
// When a source authentication carries security.AuthenticationDetails (context switching), its
// authentication time is kept. Otherwise the time is derived from userAuth via
// security.DetermineAuthenticationTime.
func (s *DefaultAuthorizationService) determineAuthenticationTime(ctx context.Context, userAuth security.Authentication, facts *authFacts) (authTime time.Time) {
	if facts.source == nil {
		return security.DetermineAuthenticationTime(ctx, userAuth)
	}
	srcAuth, ok := facts.source.Details().(security.AuthenticationDetails)
	if !ok {
		return security.DetermineAuthenticationTime(ctx, userAuth)
	}
	return srcAuth.AuthenticationTime()
}
/*
***************************
Helpers
***************************
*/
// reuseOrNewAccessToken returns a token reusable for the given authentication when the token store
// provides one; otherwise it returns a new token with a random UUID value.
// A reusable token of another implementation is converted with oauth2.FromAccessToken.
// Errors from the token store are treated as "nothing to reuse".
func (s *DefaultAuthorizationService) reuseOrNewAccessToken(c context.Context, oauth oauth2.Authentication) *oauth2.DefaultAccessToken {
	existing, e := s.tokenStore.ReusableAccessToken(c, oauth)
	if e != nil || existing == nil {
		return oauth2.NewDefaultAccessToken(uuid.New().String())
	}
	if t, ok := existing.(*oauth2.DefaultAccessToken); ok {
		return t
	}
	// convert the existing token itself; the failed assertion above only yields a nil pointer
	return oauth2.FromAccessToken(existing)
}
// minTime picks the earlier of two instants.
// A zero t1 is always returned as-is. Otherwise t1 is returned only when it is strictly before t2,
// so ties (including the same instant in different locations) resolve to t2.
func minTime(t1, t2 time.Time) time.Time {
	if !t1.IsZero() && !t1.Before(t2) {
		return t2
	}
	return t1
}
//func ConvertToUserAuthenticationWithPermissions(
// userAuth security.Authentication,
// account security.Account,
//) oauth2.UserAuthentication {
// principal, e := security.GetUsername(userAuth)
// if e != nil {
// principal = fmt.Sprintf("%v", userAuth)
// }
//
// details, ok := userAuth.Details().(map[string]interface{})
// if !ok {
// details = map[string]interface{}{
// "Literal": userAuth.Details(),
// }
// }
// permissions := make(map[string]interface{})
// for _, permission := range account.Permissions() {
// permissions[permission] = nil
// }
//
// return oauth2.NewUserAuthentication(func(opt *oauth2.UserAuthOption) {
// opt.Principal = principal
// opt.Permissions = permissions
// opt.State = userAuth.State()
// opt.Details = details
// })
//}
/*
***************************
Errors
***************************
*/
// newTamperedIDOrUsernameError is an internal error: a finalizer changed the account's ID or username.
func newTamperedIDOrUsernameError(reasons ...interface{}) error {
	return oauth2.NewInternalError("finalizer tampered with the ID or Username field", reasons...)
}

// newTamperedTenancyError is an internal error: a finalizer changed the account's tenancy.
func newTamperedTenancyError(reasons ...interface{}) error {
	return oauth2.NewInternalError("finalizer tampered with the tenancy of the account", reasons...)
}

// newImmutableContextError is an internal error: the context was expected to be mutable.
func newImmutableContextError(reasons ...interface{}) error {
	return oauth2.NewInternalError("context is not mutable", reasons...)
}

// newInvalidClientError is an invalid grant error for an unknown client.
func newInvalidClientError(reasons ...interface{}) error {
	return oauth2.NewInvalidGrantError("trying to authorize with unknown client", reasons...)
}

// newInvalidTenantForClientError is an invalid grant error: the client may not access the requested tenant.
func newInvalidTenantForClientError(reasons ...interface{}) error {
	return oauth2.NewInvalidGrantError("authenticated client doesn't have access to the requested tenant", reasons...)
}

// newUnauthenticatedUserError is an invalid grant error for a user that is not authenticated.
func newUnauthenticatedUserError(reasons ...interface{}) error {
	return oauth2.NewInvalidGrantError("trying to authorize with unauthenticated user", reasons...)
}

// newInvalidUserError is an invalid grant error for an invalid authorizing user.
func newInvalidUserError(reasons ...interface{}) error {
	return oauth2.NewInvalidGrantError("invalid authorizing user", reasons...)
}

// newInvalidTenantForUserError is an invalid grant error: the user may not access the requested tenant.
func newInvalidTenantForUserError(reasons ...interface{}) error {
	return oauth2.NewInvalidGrantError("authenticated user does not have access to the requested tenant", reasons...)
}

// newInvalidProviderError is an invalid grant error: the user may not access the requested provider.
func newInvalidProviderError(reasons ...interface{}) error {
	return oauth2.NewInvalidGrantError("authenticated user does not have access to the requested provider", reasons...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package token
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/cisco-open/go-lanai/pkg/web/mapping"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
"net/http"
)
var (
// FeatureId identifies the OAuth2 token endpoint feature. It is created with the
// security.FeatureOrderOAuth2TokenEndpoint order constant.
FeatureId = security.FeatureId("OAuth2AuthToken", security.FeatureOrderOAuth2TokenEndpoint)
)
// TokenEndpointConfigurer applies a TokenFeature to a security.WebSecurity. It holds no state.
//goland:noinspection GoNameStartsWithPackageName
type TokenEndpointConfigurer struct{}

// newOAuth2TokenEndpointConfigurer returns a ready-to-use, stateless TokenEndpointConfigurer.
func newOAuth2TokenEndpointConfigurer() *TokenEndpointConfigurer {
	return new(TokenEndpointConfigurer)
}
// Apply installs the token endpoint described by the given *TokenFeature into ws.
// It installs two things:
//   - a middleware that handles POST requests on the feature's path with a composite of its granters;
//   - a no-op handler mapped to the same path.
//
// It returns an error when feature is not a *TokenFeature or when the feature fails validation.
func (c *TokenEndpointConfigurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	// Verify
	f, ok := feature.(*TokenFeature)
	if !ok {
		return fmt.Errorf("unsupported feature [%T], expecting *TokenFeature", feature)
	}
	if err := c.validate(f, ws); err != nil {
		return err
	}

	// prepare middlewares
	tokenMw := NewTokenEndpointMiddleware(func(opts *TokenEndpointOptions) {
		opts.Granter = auth.NewCompositeTokenGranter(f.granters...)
	})

	// install middlewares
	tokenMapping := middleware.NewBuilder("token endpoint").
		ApplyTo(matcher.RouteWithPattern(f.path, http.MethodPost)).
		Order(security.MWOrderOAuth2Endpoints).
		Use(tokenMw.TokenHandlerFunc())
	ws.Add(tokenMapping)

	// add dummy handler so the route exists; the middleware above produces the response
	ws.Add(mapping.Post(f.path).HandlerFunc(security.NoopHandlerFunc()))
	return nil
}
// validate ensures the feature has an endpoint path and at least one token granter.
// The WebSecurity argument is currently unused.
func (c *TokenEndpointConfigurer) validate(f *TokenFeature, _ security.WebSecurity) error {
	if f.path == "" {
		return fmt.Errorf("token endpoint is not set")
	}
	// len of a nil slice is 0, so no separate nil check is needed
	if len(f.granters) == 0 {
		return fmt.Errorf("token granters is not set")
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package token
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
)
// TokenFeature configures the OAuth2 token endpoint. It holds the endpoint path and the token
// granters used to serve token requests. Both are set through the Path and AddGranter setters.
//goland:noinspection GoNameStartsWithPackageName
type TokenFeature struct {
path string
granters []auth.TokenGranter
}
// Identifier returns FeatureId. This is the standard security.Feature entrypoint.
func (f *TokenFeature) Identifier() security.FeatureIdentifier {
return FeatureId
}
// Configure enables a new TokenFeature on ws and returns it for further configuration.
// It panics when ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *TokenFeature {
	feature := NewEndpoint()
	modifier, ok := ws.(security.FeatureModifier)
	if !ok {
		panic(fmt.Errorf("unable to configure oauth2 authserver: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
	}
	return modifier.Enable(feature).(*TokenFeature)
}

// NewEndpoint returns an empty TokenFeature.
// It is the standard security.Feature entrypoint, DSL style, used with security.WebSecurity.
func NewEndpoint() *TokenFeature {
	return new(TokenFeature)
}
/** Setters **/

// Path sets the token endpoint path and returns f for chaining.
func (f *TokenFeature) Path(path string) *TokenFeature {
	f.path = path
	return f
}

// AddGranter appends granter to the feature and returns f for chaining.
// A *auth.CompositeTokenGranter is flattened: its delegates are appended instead of the composite.
func (f *TokenFeature) AddGranter(granter auth.TokenGranter) *TokenFeature {
	switch g := granter.(type) {
	case *auth.CompositeTokenGranter:
		f.granters = append(f.granters, g.Delegates()...)
	default:
		f.granters = append(f.granters, granter)
	}
	return f
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package token
import (
"errors"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
"github.com/gin-gonic/gin"
)
/***********************
Token Endpoint
***********************/
// TokenEndpointMiddleware serves token requests by delegating to a token granter.
//goland:noinspection GoNameStartsWithPackageName
type TokenEndpointMiddleware struct {
	granter auth.TokenGranter
}

// TokenEndpointOptionsFunc customizes TokenEndpointOptions.
//goland:noinspection GoNameStartsWithPackageName
type TokenEndpointOptionsFunc func(*TokenEndpointOptions)

// TokenEndpointOptions configures NewTokenEndpointMiddleware.
//goland:noinspection GoNameStartsWithPackageName
type TokenEndpointOptions struct {
	Granter *auth.CompositeTokenGranter
}

// NewTokenEndpointMiddleware builds a TokenEndpointMiddleware.
// The granter defaults to an empty composite granter. Option funcs are applied in order, and nil
// funcs are ignored.
func NewTokenEndpointMiddleware(optionFuncs ...TokenEndpointOptionsFunc) *TokenEndpointMiddleware {
	opts := &TokenEndpointOptions{Granter: auth.NewCompositeTokenGranter()}
	for _, apply := range optionFuncs {
		if apply == nil {
			continue
		}
		apply(opts)
	}
	return &TokenEndpointMiddleware{granter: opts.Granter}
}
// TokenHandlerFunc returns the gin handler of the token endpoint. The handler:
//  1. requires an authenticated client on the context;
//  2. parses the token request from the request form;
//  3. rejects a client ID parameter that differs from the authenticated client;
//  4. records the client's session-timeout setting in the request extensions;
//  5. checks that the client may use the requested grant type, and rejects the implicit grant;
//  6. delegates to the configured granter and writes the resulting token.
//
// Every failure is reported through handleError, which aborts the handler chain.
func (mw *TokenEndpointMiddleware) TokenHandlerFunc() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		// first we double check if client is authenticated
		client := auth.RetrieveAuthenticatedClient(ctx)
		if client == nil {
			mw.handleError(ctx, oauth2.NewClientNotFoundError("invalid client"))
			return
		}

		// parse request
		tokenRequest, e := auth.ParseTokenRequest(ctx.Request)
		if e != nil {
			mw.handleError(ctx, oauth2.NewInvalidTokenRequestError("invalid token request", e))
			return
		}

		// see if client id matches
		if tokenRequest.ClientId != "" && tokenRequest.ClientId != client.ClientId() {
			mw.handleError(ctx, oauth2.NewInvalidTokenRequestError("given client ID does not match authenticated client"))
			return
		}
		tokenRequest.Extensions[oauth2.ExtUseSessionTimeout] = client.UseSessionTimeout()

		// check grant
		if e := auth.ValidateGrant(ctx, client, tokenRequest.GrantType); e != nil {
			mw.handleError(ctx, e)
			return
		}

		// check if supported: implicit grant is not served by the token endpoint
		if tokenRequest.GrantType == oauth2.GrantTypeImplicit {
			mw.handleError(ctx, oauth2.NewInvalidGrantError("implicit grant type not supported from token endpoint"))
			return
		}

		token, e := mw.granter.Grant(ctx, tokenRequest)
		if e != nil {
			mw.handleError(ctx, e)
			return
		}
		mw.handleSuccess(ctx, token)
	}
}
// handleSuccess writes v as a 200 JSON response with no-cache headers and aborts the handler chain.
func (mw *TokenEndpointMiddleware) handleSuccess(c *gin.Context, v interface{}) {
	// token responses must not be cached
	noCacheHeaders := [][2]string{
		{"Cache-Control", "no-store"},
		{"Pragma", "no-cache"},
	}
	for _, h := range noCacheHeaders {
		c.Header(h[0], h[1])
	}
	c.JSON(200, v)
	c.Abort()
}
// handleError attaches err to the gin context and aborts the handler chain.
// Errors that are not OAuth2 errors are wrapped as invalid grant errors. Errors that already are
// OAuth2 errors (e.g. invalid client, invalid request) are passed through unchanged, so their
// specific error code is preserved.
func (mw *TokenEndpointMiddleware) handleError(c *gin.Context, err error) {
	if !errors.Is(err, oauth2.ErrorTypeOAuth2) {
		err = oauth2.NewInvalidGrantError(err)
	}
	_ = c.Error(err)
	c.Abort()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package token
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"go.uber.org/fx"
)
// Module is the bootstrap module of the OAuth2 token endpoint.
// Its only option invokes register, which registers the token endpoint configurer under FeatureId.
//goland:noinspection GoNameStartsWithPackageName
var Module = &bootstrap.Module{
Name: "oauth2 auth - token",
Precedence: security.MinSecurityPrecedence + 20,
Options: []fx.Option{
fx.Invoke(register),
},
}
// init registers Module with bootstrap when this package is imported.
func init() {
bootstrap.Register(Module)
}
// initDI holds the dependencies injected into register.
// SecRegistrar is optional; when it is absent, register does nothing.
type initDI struct {
fx.In
SecRegistrar security.Registrar `optional:"true"`
}
// register registers a new token endpoint configurer under FeatureId.
// It is a no-op when no security.Registrar is injected. The registrar is expected to also implement
// security.FeatureRegistrar; the type assertion panics otherwise.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	configurer := newOAuth2TokenEndpointConfigurer()
	registrar := di.SecRegistrar.(security.FeatureRegistrar)
	registrar.RegisterFeature(FeatureId, configurer)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils/order"
)
/*****************************
Abstraction
*****************************/
// TokenEnhancer modifies the given oauth2.AccessToken, or returns a new token, based on the given
// context and authentication.
// Most TokenEnhancers add or modify claims of the given access token, but they are not limited to
// that. For example, a TokenEnhancer could be responsible for installing a refresh token.
// Usually, if the given token is not mutable, the returned token would be a different instance.
type TokenEnhancer interface {
Enhance(ctx context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) (oauth2.AccessToken, error)
}
/*****************************
Common Implementations
*****************************/
// CompositeTokenEnhancer chains several TokenEnhancers and runs them in slice order.
type CompositeTokenEnhancer struct {
	delegates []TokenEnhancer
}

// NewCompositeTokenEnhancer wraps the given delegates as they are (no flattening or sorting).
func NewCompositeTokenEnhancer(delegates ...TokenEnhancer) *CompositeTokenEnhancer {
	composite := &CompositeTokenEnhancer{}
	composite.delegates = delegates
	return composite
}
// Enhance passes the token through every delegate in order; each delegate receives the previous
// delegate's output. The first error stops the chain and is returned with a nil token.
func (e *CompositeTokenEnhancer) Enhance(ctx context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) (oauth2.AccessToken, error) {
	enhanced := token
	for _, delegate := range e.delegates {
		next, err := delegate.Enhance(ctx, enhanced, oauth)
		if err != nil {
			return nil, err
		}
		enhanced = next
	}
	return enhanced, nil
}
// Add appends enhancers, flattening nested composites, and re-sorts all delegates with a stable
// sort using order.OrderedFirstCompare.
func (e *CompositeTokenEnhancer) Add(enhancers ...TokenEnhancer) {
	flattened := flattenEnhancers(enhancers)
	e.delegates = append(e.delegates, flattened...)
	order.SortStable(e.delegates, order.OrderedFirstCompare)
}
// Remove deletes the first delegate equal to enhancer, keeping the order of the remaining delegates.
// It does nothing when enhancer is not found.
// NOTE: delegates are compared with ==, which panics if a delegate's dynamic type is not comparable.
func (e *CompositeTokenEnhancer) Remove(enhancer TokenEnhancer) {
	for i, item := range e.delegates {
		if item != enhancer {
			continue
		}
		last := len(e.delegates) - 1
		// remove but keep order: shift the tail left by one
		copy(e.delegates[i:], e.delegates[i+1:])
		// clear the vacated slot so the backing array doesn't keep the removed enhancer alive
		e.delegates[last] = nil
		e.delegates = e.delegates[:last]
		return
	}
}
// flattenEnhancers recursively replaces every *CompositeTokenEnhancer with its delegates.
// Other enhancers are kept in their original relative order.
func flattenEnhancers(enhancers []TokenEnhancer) (ret []TokenEnhancer) {
	ret = make([]TokenEnhancer, 0, len(enhancers))
	for _, enhancer := range enhancers {
		if composite, ok := enhancer.(*CompositeTokenEnhancer); ok {
			ret = append(ret, flattenEnhancers(composite.delegates)...)
			continue
		}
		ret = append(ret, enhancer)
	}
	return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/google/uuid"
"time"
)
const (
// errTmplUnsupportedToken is the message used when a token is not *oauth2.DefaultAccessToken;
// %T is meant to receive the offending token.
errTmplUnsupportedToken = `unsupported token implementation %T`
)
/*****************************
Expiry Time Enhancer
*****************************/
// ExpiryTokenEnhancer implements order.Ordered and TokenEnhancer.
// It sets the token's issue and expiry time.
type ExpiryTokenEnhancer struct {}
// Order returns TokenEnhancerOrderExpiry.
func (e *ExpiryTokenEnhancer) Order() int {
return TokenEnhancerOrderExpiry
}
// Enhance sets the token's issue and expiry time.
// When the authentication details implement security.AuthenticationDetails, both times are copied
// from them; otherwise only the issue time is set, to the current UTC time.
// Only *oauth2.DefaultAccessToken is supported; other implementations yield an internal error.
func (e *ExpiryTokenEnhancer) Enhance(_ context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) (oauth2.AccessToken, error) {
	t, ok := token.(*oauth2.DefaultAccessToken)
	if !ok {
		// report the actual token: t is a nil *DefaultAccessToken after a failed assertion
		return nil, oauth2.NewInternalError(errTmplUnsupportedToken, token)
	}

	if authDetails, ok := oauth.Details().(security.AuthenticationDetails); ok {
		t.SetIssueTime(authDetails.IssueTime())
		t.SetExpireTime(authDetails.ExpiryTime())
	} else {
		t.SetIssueTime(time.Now().UTC())
	}
	return t, nil
}
/*****************************
Details Enhancer
*****************************/
// DetailsTokenEnhancer implements order.Ordered and TokenEnhancer.
// It populates the token's additional metadata other than claims and issue/expiry time; currently
// that is the scopes of the OAuth2 request.
type DetailsTokenEnhancer struct {}
// Order returns TokenEnhancerOrderTokenDetails.
func (e *DetailsTokenEnhancer) Order() int {
return TokenEnhancerOrderTokenDetails
}
// Enhance copies the OAuth2 request's scopes onto the token.
// Only *oauth2.DefaultAccessToken is supported; other implementations yield an internal error.
func (e *DetailsTokenEnhancer) Enhance(_ context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) (oauth2.AccessToken, error) {
	t, ok := token.(*oauth2.DefaultAccessToken)
	if !ok {
		// report the actual token: t is a nil *DefaultAccessToken after a failed assertion
		return nil, oauth2.NewInternalError(errTmplUnsupportedToken, token)
	}
	t.SetScopes(oauth.OAuth2Request().Scopes())
	return t, nil
}
/*****************************
BasicClaims Enhancer
*****************************/
// BasicClaimsTokenEnhancer implements order.Ordered and TokenEnhancer.
// It sets the token's claims to oauth2.BasicClaims built from the OAuth2 request, the issuer,
// the user authentication and the token's issue/expiry time.
type BasicClaimsTokenEnhancer struct {
issuer security.Issuer // its Identifier() becomes the Issuer claim
}
// Order returns TokenEnhancerOrderBasicClaims.
func (te *BasicClaimsTokenEnhancer) Order() int {
return TokenEnhancerOrderBasicClaims
}
// Enhance replaces the token's claims with oauth2.BasicClaims. The claims are built as follows:
//   - ID: a new UUID, unless the existing claims already carry a string JWT ID, which is kept;
//   - Audience and ClientId: the request's client ID;
//   - Issuer: the issuer's identifier;
//   - Scopes: a copy of the request's scopes;
//   - Subject: taken from the user authentication, when present;
//   - ExpiresAt, IssuedAt and NotBefore: taken from the token's non-zero expiry/issue times.
//
// Only *oauth2.DefaultAccessToken is supported; other implementations yield an internal error.
func (te *BasicClaimsTokenEnhancer) Enhance(_ context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) (oauth2.AccessToken, error) {
	t, ok := token.(*oauth2.DefaultAccessToken)
	if !ok {
		// report the actual token: t is a nil *DefaultAccessToken after a failed assertion
		return nil, oauth2.NewInternalError(errTmplUnsupportedToken, token)
	}

	request := oauth.OAuth2Request()
	basic := &oauth2.BasicClaims{
		Id:       uuid.New().String(),
		Audience: oauth2.StringSetClaim(utils.NewStringSet(request.ClientId())),
		Issuer:   te.issuer.Identifier(),
		ClientId: request.ClientId(),
		Scopes:   request.Scopes().Copy(),
	}

	if t.Claims() != nil && t.Claims().Has(oauth2.ClaimJwtId) {
		// keep the existing JWT ID; a non-string value is ignored rather than panicking
		if jti, ok := t.Claims().Get(oauth2.ClaimJwtId).(string); ok {
			basic.Id = jti
		}
	}

	if userAuth := oauth.UserAuthentication(); userAuth != nil {
		sub, e := extractSubject(userAuth)
		if e != nil {
			return nil, e
		}
		basic.Subject = sub
	}

	if !t.ExpiryTime().IsZero() {
		basic.ExpiresAt = t.ExpiryTime()
	}

	if !t.IssueTime().IsZero() {
		basic.IssuedAt = t.IssueTime()
		basic.NotBefore = t.IssueTime()
	}
	t.SetClaims(basic)
	return t, nil
}
// extractSubject derives a subject string from the authentication's principal.
// A string principal is returned as-is, a security.Account yields its username, and a fmt.Stringer
// yields its String(). Checks happen in that order. Any other principal yields an internal error.
func extractSubject(auth security.Authentication) (string, error) {
	switch principal := auth.Principal().(type) {
	case string:
		return principal, nil
	case security.Account:
		return principal.Username(), nil
	case fmt.Stringer:
		return principal.String(), nil
	default:
		return "", oauth2.NewInternalError("unable to extract subject for authentication %T", auth)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/claims"
)
/*****************************
legacyClaims Enhancer
*****************************/
// legacyClaims implements Claims and includes BasicClaims.
// On top of the embedded basic claims it carries the legacy user/tenant claims that
// LegacyTokenEnhancer populates. Every accessor below delegates to the embedded FieldClaimsMapper,
// passing the receiver itself.
type legacyClaims struct {
oauth2.FieldClaimsMapper
*oauth2.BasicClaims
FirstName string `claim:"firstName"`
LastName string `claim:"lastName"`
Email string `claim:"email"`
TenantId string `claim:"tenantId"`
Username string `claim:"user_name"`
Roles []string `claim:"roles"`
}
// MarshalJSON encodes the claims via FieldClaimsMapper.DoMarshalJSON.
func (c *legacyClaims) MarshalJSON() ([]byte, error) {
return c.FieldClaimsMapper.DoMarshalJSON(c)
}
// UnmarshalJSON decodes the claims via FieldClaimsMapper.DoUnmarshalJSON.
func (c *legacyClaims) UnmarshalJSON(bytes []byte) error {
return c.FieldClaimsMapper.DoUnmarshalJSON(c, bytes)
}
// Get returns the value of the named claim via FieldClaimsMapper.Get.
func (c *legacyClaims) Get(claim string) interface{} {
return c.FieldClaimsMapper.Get(c, claim)
}
// Has reports whether the named claim is present via FieldClaimsMapper.Has.
func (c *legacyClaims) Has(claim string) bool {
return c.FieldClaimsMapper.Has(c, claim)
}
// Set assigns the named claim via FieldClaimsMapper.Set.
func (c *legacyClaims) Set(claim string, value interface{}) {
c.FieldClaimsMapper.Set(c, claim, value)
}
// Values returns all claims as a map via FieldClaimsMapper.Values.
func (c *legacyClaims) Values() map[string]interface{} {
return c.FieldClaimsMapper.Values(c)
}
// LegacyTokenEnhancer implements order.Ordered and TokenEnhancer.
// It adds legacy claims and response fields that were supported by the Java version of IDM but are
// deprecated in the Go version.
type LegacyTokenEnhancer struct{}
// NewLegacyTokenEnhancer returns a new LegacyTokenEnhancer as a TokenEnhancer.
func NewLegacyTokenEnhancer() TokenEnhancer {
return &LegacyTokenEnhancer{}
}
// Order returns TokenEnhancerOrderDetailsClaims.
func (te *LegacyTokenEnhancer) Order() int {
return TokenEnhancerOrderDetailsClaims
}
// Enhance wraps the token's *oauth2.BasicClaims into legacyClaims. It also fills the legacy
// user/tenant/role claims from the authentication details, for whichever detail interfaces those
// details implement.
// The username, the legacy tenant ID and the roles are also put into the token details.
// The token must already carry *oauth2.BasicClaims (from BasicClaimsTokenEnhancer); otherwise, or
// when the token is not *oauth2.DefaultAccessToken, an internal error is returned.
func (te *LegacyTokenEnhancer) Enhance(_ context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) (oauth2.AccessToken, error) {
	t, ok := token.(*oauth2.DefaultAccessToken)
	if !ok {
		// report the actual token: t is a nil *DefaultAccessToken after a failed assertion
		return nil, oauth2.NewInternalError("unsupported token implementation %T", token)
	}

	// a nil Claims() also fails this assertion
	basic, ok := t.Claims().(*oauth2.BasicClaims)
	if !ok {
		return nil, oauth2.NewInternalError("LegacyTokenEnhancer need to be placed immediately after BasicClaimsEnhancer")
	}

	legacy := &legacyClaims{
		BasicClaims: basic,
		Username:    basic.Subject,
	}
	t.PutDetails(oauth2.ClaimUsername, legacy.Username)

	if ud, ok := oauth.Details().(security.UserDetails); ok {
		legacy.FirstName = ud.FirstName()
		legacy.LastName = ud.LastName()
		legacy.Email = ud.Email()
	}

	if td, ok := oauth.Details().(security.TenantDetails); ok {
		legacy.TenantId = td.TenantId()
		t.PutDetails(oauth2.ClaimLegacyTenantId, td.TenantId())
	}

	if ad, ok := oauth.Details().(security.AuthenticationDetails); ok {
		legacy.Roles = ad.Roles().Values()
		t.PutDetails(oauth2.ClaimRoles, legacy.Roles)
	}

	t.SetClaims(legacy)
	return t, nil
}
// ResourceIdTokenEnhancer implements order.Ordered and TokenEnhancer.
// The spring-security-oauth2 based Java implementation expects the "aud" claim to be the resource ID.
type ResourceIdTokenEnhancer struct {
}
// Order returns TokenEnhancerOrderResourceIdClaims.
func (te *ResourceIdTokenEnhancer) Order() int {
return TokenEnhancerOrderResourceIdClaims
}
// Enhance overwrites the audience claim with the value computed by claims.LegacyAudience for the
// given authentication.
// The token must be *oauth2.DefaultAccessToken and already carry an audience claim (i.e. run after
// BasicClaimsTokenEnhancer); otherwise an internal error is returned.
func (te *ResourceIdTokenEnhancer) Enhance(c context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) (oauth2.AccessToken, error) {
	t, ok := token.(*oauth2.DefaultAccessToken)
	if !ok {
		// report the actual token: t is a nil *DefaultAccessToken after a failed assertion
		return nil, oauth2.NewInternalError("unsupported token implementation %T", token)
	}

	if t.Claims() == nil || !t.Claims().Has(oauth2.ClaimAudience) {
		return nil, oauth2.NewInternalError("ResourceIdTokenEnhancer need to be placed after BasicClaimsEnhancer")
	}

	aud := claims.LegacyAudience(c, &claims.FactoryOption{
		Source: oauth,
	})
	t.Claims().Set(oauth2.ClaimAudience, aud)
	return t, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/google/uuid"
)
var (
// refreshTokenAllowedGrants lists the grant types for which RefreshTokenEnhancer may issue a refresh token.
// NOTE(review): RFC 6749 §4.2.2 states that refresh tokens MUST NOT be issued for the implicit grant;
// confirm whether including GrantTypeImplicit here is intentional (e.g. legacy compatibility).
refreshTokenAllowedGrants = utils.NewStringSet(
oauth2.GrantTypeAuthCode,
oauth2.GrantTypeImplicit,
oauth2.GrantTypeRefresh,
oauth2.GrantTypeSwitchTenant, // Need this to create a new refresh token when switching tenants
//oauth2.GrantTypePassword, // this is for dev purpose, shouldn't be allowed
)
)
/*****************************
RefreshToken Enhancer
*****************************/
// RefreshTokenEnhancer implements order.Ordered and TokenEnhancer.
// It creates a refresh token, saves it through tokenStore, and associates it with the given access token.
type RefreshTokenEnhancer struct {
tokenStore TokenStore
issuer security.Issuer // its Identifier() becomes the refresh token's Issuer claim
}
// Order returns TokenEnhancerOrderRefreshToken.
func (te *RefreshTokenEnhancer) Order() int {
return TokenEnhancerOrderRefreshToken
}
// Enhance creates a new refresh token and attaches it to the access token when one is needed
// (see isRefreshTokenNeeded). When ctx has no authenticated client, or no refresh token is needed,
// token is returned unchanged.
//
// The new refresh token works as follows:
//   - it always gets a fresh random ID (refresh tokens are never reused);
//   - it expires RefreshTokenValidity after the authentication time, when both are known;
//   - it carries basic claims: ID, audience (client ID), issuer, scopes, plus subject and expiry when available.
//
// Saving is best-effort: if the token store fails to save the refresh token, the access token is
// returned without the new refresh token and no error is reported.
func (te *RefreshTokenEnhancer) Enhance(ctx context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) (oauth2.AccessToken, error) {
	// step 1, check if refresh token is needed
	client, ok := ctx.Value(oauth2.CtxKeyAuthenticatedClient).(oauth2.OAuth2Client)
	if !ok || !te.isRefreshTokenNeeded(ctx, token, oauth, client) {
		return token, nil
	}

	t, ok := token.(*oauth2.DefaultAccessToken)
	if !ok {
		// report the actual token: t is a nil *DefaultAccessToken after a failed assertion
		return nil, oauth2.NewInternalError("unsupported token implementation %T", token)
	}

	// step 2, create refresh token
	// Note: we don't reuse refresh token
	id := uuid.New().String()
	refresh := oauth2.NewDefaultRefreshToken(id)

	// step 3, set expiry time
	// Note: refresh token's validity is counted since authentication time
	details, ok := oauth.Details().(security.AuthenticationDetails)
	if ok && client.RefreshTokenValidity() > 0 && !details.AuthenticationTime().IsZero() {
		refresh.SetExpireTime(details.AuthenticationTime().Add(client.RefreshTokenValidity()))
	}

	// step 4, create claims
	request := oauth.OAuth2Request()
	claims := oauth2.BasicClaims{
		Id:       id,
		Audience: oauth2.StringSetClaim(utils.NewStringSet(client.ClientId())),
		Issuer:   te.issuer.Identifier(),
		// copied (as BasicClaimsTokenEnhancer does) so the claims don't alias the request's scope set
		Scopes: request.Scopes().Copy(),
	}

	if userAuth := oauth.UserAuthentication(); userAuth != nil {
		sub, e := extractSubject(userAuth)
		if e != nil {
			return nil, e
		}
		claims.Subject = sub
	}

	if refresh.WillExpire() && !refresh.ExpiryTime().IsZero() {
		claims.Set(oauth2.ClaimExpire, refresh.ExpiryTime())
	}
	refresh.SetClaims(&claims)

	// step 5, save refresh token (best-effort, see doc comment)
	if saved, e := te.tokenStore.SaveRefreshToken(ctx, refresh, oauth); e == nil {
		t.SetRefreshToken(saved)
	}
	return t, nil
}
/*****************************
Helpers
*****************************/
// isRefreshTokenNeeded reports whether a new refresh token should be issued. All of the following
// must hold:
//   - the client is allowed to use the refresh grant;
//   - the request's grant type is in refreshTokenAllowedGrants;
//   - the token has no refresh token, or its refresh token will expire and already has.
func (te *RefreshTokenEnhancer) isRefreshTokenNeeded(ctx context.Context, token oauth2.AccessToken, oauth oauth2.Authentication, client oauth2.OAuth2Client) bool {
	switch {
	case ValidateGrant(ctx, client, oauth2.GrantTypeRefresh) != nil:
		return false
	case !refreshTokenAllowedGrants.Has(oauth.OAuth2Request().GrantType()):
		return false
	}

	current := token.RefreshToken()
	if current == nil {
		return true
	}
	// an existing refresh token is only replaced once it has expired
	return current.WillExpire() && current.Expired()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"net/http"
)
// TokenRequest is the parsed form of a request to the token endpoint.
// OAuth2Request converts it into an oauth2.OAuth2Request for a given client.
type TokenRequest struct {
Parameters map[string]string // request parameters as strings
ClientId string // client ID given in the request, empty if absent
Scopes utils.StringSet // requested scopes
GrantType string // requested grant type
Extensions map[string]interface{} // extra values carried into the resulting OAuth2 request
context utils.MutableContext // request context, see Context and WithContext
}
// Context returns the mutable context attached to the request.
func (r *TokenRequest) Context() utils.MutableContext {
	return r.context
}

// WithContext replaces the request's context with a mutable version of ctx and returns r.
func (r *TokenRequest) WithContext(ctx context.Context) *TokenRequest {
	r.context = utils.MakeMutableContext(ctx)
	return r
}

// OAuth2Request converts the token request into an approved oauth2.OAuth2Request.
// The client ID is taken from the given client, not from the request.
func (r *TokenRequest) OAuth2Request(client oauth2.OAuth2Client) oauth2.OAuth2Request {
	return oauth2.NewOAuth2Request(func(d *oauth2.RequestDetails) {
		d.GrantType = r.GrantType
		d.ClientId = client.ClientId()
		d.Scopes = r.Scopes
		d.Parameters = r.Parameters
		d.Extensions = r.Extensions
		d.Approved = true
	})
}

// NewTokenRequest returns an empty TokenRequest with initialized maps, an empty scope set, and a
// mutable background context.
func NewTokenRequest() *TokenRequest {
	req := new(TokenRequest)
	req.Parameters = make(map[string]string)
	req.Scopes = utils.NewStringSet()
	req.Extensions = make(map[string]interface{})
	req.context = utils.NewMutableContext(context.Background())
	return req
}
// ParseTokenRequest builds a TokenRequest from the HTTP request's form values.
// It returns the error from req.ParseForm if the form cannot be parsed.
func ParseTokenRequest(req *http.Request) (*TokenRequest, error) {
	if err := req.ParseForm(); err != nil {
		return nil, err
	}
	values := flattenValuesToMap(req.Form)

	// NOTE(review): the fields are computed in the same order as the original composite literal
	// (Parameters before the extract* calls), in case those helpers modify values — confirm.
	parsed := &TokenRequest{}
	parsed.Parameters = toStringMap(values)
	parsed.ClientId = extractStringParam(oauth2.ParameterClientId, values)
	parsed.Scopes = extractStringSetParam(oauth2.ParameterScope, " ", values)
	parsed.GrantType = extractStringParam(oauth2.ParameterGrantType, values)
	parsed.Extensions = values
	parsed.context = utils.MakeMutableContext(req.Context())
	return parsed, nil
}
// String returns a short description of the request: client, grant, scopes and extensions.
// %v is used for scopes and extensions, which are not plain strings. Extensions hold values such as
// booleans, and %s would render those as "%!s(bool=...)".
func (r *TokenRequest) String() string {
	return fmt.Sprintf("[client=%s, grant=%s, scope=%v, ext=%v]",
		r.ClientId, r.GrantType, r.Scopes, r.Extensions)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/common"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
)
// jwtTokenStore implements TokenStore and delegates the oauth2.TokenStoreReader portion to the
// embedded interface.
// Saved tokens get their value from encoding their claims with jwtEncoder and are registered with
// registry; context details of access tokens are saved in detailsStore.
type jwtTokenStore struct {
oauth2.TokenStoreReader
detailsStore security.ContextDetailsStore
jwtEncoder jwt.JwtEncoder
registry AuthorizationRegistry
}
// JTSOptions is a functional option that customizes JTSOption for NewJwtTokenStore.
type JTSOptions func(opt *JTSOption)
// JTSOption holds the configuration of NewJwtTokenStore.
// When Reader is nil, a reader is built from DetailsStore and Decoder.
type JTSOption struct {
Reader oauth2.TokenStoreReader
DetailsStore security.ContextDetailsStore
Encoder jwt.JwtEncoder
Decoder jwt.JwtDecoder
AuthRegistry AuthorizationRegistry
}
// NewJwtTokenStore creates a JWT based token store.
// Options are applied in order and nil options are skipped, consistent with NewTokenEndpointMiddleware.
// When no Reader is configured, a common JWT token store reader is created from the configured
// DetailsStore and Decoder.
func NewJwtTokenStore(opts ...JTSOptions) *jwtTokenStore {
	opt := JTSOption{}
	for _, optFunc := range opts {
		if optFunc != nil {
			optFunc(&opt)
		}
	}

	if opt.Reader == nil {
		opt.Reader = common.NewJwtTokenStoreReader(func(o *common.JTSROption) {
			o.DetailsStore = opt.DetailsStore
			o.Decoder = opt.Decoder
		})
	}

	return &jwtTokenStore{
		TokenStoreReader: opt.Reader,
		detailsStore:     opt.DetailsStore,
		jwtEncoder:       opt.Encoder,
		registry:         opt.AuthRegistry,
	}
}
// ReadAuthentication resolves the authentication associated with tokenValue.
// For a refresh-token hint, the authorization stored in the registry is returned;
// every other hint is handled by the embedded oauth2.TokenStoreReader.
func (s *jwtTokenStore) ReadAuthentication(ctx context.Context, tokenValue string, hint oauth2.TokenHint) (oauth2.Authentication, error) {
	if hint == oauth2.TokenHintRefreshToken {
		return s.readAuthenticationFromRefreshToken(ctx, tokenValue)
	}
	return s.TokenStoreReader.ReadAuthentication(ctx, tokenValue, hint)
}
// ReusableAccessToken never offers an existing token for reuse: it always returns (nil, nil),
// so every request results in a freshly issued JWT access token.
func (s *jwtTokenStore) ReusableAccessToken(_ context.Context, _ oauth2.Authentication) (oauth2.AccessToken, error) {
	var reusable oauth2.AccessToken
	return reusable, nil
}
// SaveAccessToken encodes the token's claims as a JWT and sets the result as the token value.
// If the authentication's details are security.ContextDetails, they are persisted through the
// details store. Finally the token is registered with the authorization registry.
// Only *oauth2.DefaultAccessToken with non-nil claims is supported.
func (s *jwtTokenStore) SaveAccessToken(c context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) (oauth2.AccessToken, error) {
	t, ok := token.(*oauth2.DefaultAccessToken)
	switch {
	case !ok:
		return nil, oauth2.NewInternalError(fmt.Sprintf("Unsupported token implementation [%T]", token))
	case t.Claims() == nil:
		return nil, oauth2.NewInternalError("claims is nil")
	}

	encoded, err := s.jwtEncoder.Encode(c, t.Claims())
	if err != nil {
		return nil, err
	}
	t.SetValue(encoded)

	details, isCtxDetails := oauth.Details().(security.ContextDetails)
	if isCtxDetails {
		if err := s.detailsStore.SaveContextDetails(c, token, details); err != nil {
			return nil, oauth2.NewInternalError("cannot save access token", err)
		}
	}

	if err := s.registry.RegisterAccessToken(c, t, oauth); err != nil {
		return nil, oauth2.NewInternalError("cannot register access token", err)
	}
	return t, nil
}
// SaveRefreshToken encodes the refresh token's claims as a JWT, sets the encoded string as the
// token value and registers the token (with its authentication) in the authorization registry.
//
// Only *oauth2.DefaultRefreshToken with non-nil claims is supported. Unsupported input is reported
// with oauth2.NewInternalError, which matches how SaveAccessToken reports the same conditions.
// Encoder errors are returned as-is; registry failures are wrapped as internal errors.
func (s *jwtTokenStore) SaveRefreshToken(c context.Context, token oauth2.RefreshToken, oauth oauth2.Authentication) (oauth2.RefreshToken, error) {
	t, ok := token.(*oauth2.DefaultRefreshToken)
	if !ok {
		return nil, oauth2.NewInternalError(fmt.Sprintf("Unsupported token implementation [%T]", token))
	}
	if t.Claims() == nil {
		return nil, oauth2.NewInternalError("claims is nil")
	}

	encoded, e := s.jwtEncoder.Encode(c, t.Claims())
	if e != nil {
		return nil, e
	}
	t.SetValue(encoded)

	if e := s.registry.RegisterRefreshToken(c, t, oauth); e != nil {
		return nil, oauth2.NewInternalError("cannot register refresh token", e)
	}
	return t, nil
}
// RemoveAccessToken revokes access tokens through the authorization registry.
// An oauth2.AccessToken revokes just that token; an oauth2.RefreshToken revokes every access
// token associated with it. Any other token (including nil) is ignored and nil is returned.
func (s *jwtTokenStore) RemoveAccessToken(c context.Context, token oauth2.Token) error {
	if access, isAccess := token.(oauth2.AccessToken); isAccess {
		return s.registry.RevokeAccessToken(c, access)
	}
	if refresh, isRefresh := token.(oauth2.RefreshToken); isRefresh {
		return s.registry.RevokeAllAccessTokens(c, refresh)
	}
	return nil
}
// RemoveRefreshToken asks the authorization registry to revoke the refresh token; the registry
// is expected to drop the refresh token together with all access tokens issued from it.
func (s *jwtTokenStore) RemoveRefreshToken(c context.Context, token oauth2.RefreshToken) error {
	err := s.registry.RevokeRefreshToken(c, token)
	return err
}
/********************
Helpers
********************/
// readAuthenticationFromRefreshToken decodes the refresh token value via the embedded reader and
// returns the authorization the registry stored for it. A token without claims, or one unknown to
// the registry, is reported as an invalid grant.
func (s *jwtTokenStore) readAuthenticationFromRefreshToken(c context.Context, tokenValue string) (oauth2.Authentication, error) {
	refresh, err := s.ReadRefreshToken(c, tokenValue)
	if err != nil {
		return nil, err
	}

	container, isContainer := refresh.(oauth2.ClaimsContainer)
	if !isContainer || container.Claims() == nil {
		return nil, oauth2.NewInvalidGrantError("refresh token contains no claims")
	}

	auth, err := s.registry.ReadStoredAuthorization(c, refresh)
	if err != nil {
		return nil, oauth2.NewInvalidGrantError("refresh token unknown", err)
	}
	return auth, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"fmt"
"github.com/bmatcuk/doublestar/v4"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
"net/url"
"regexp"
"strings"
)
// Regex building blocks (named capture groups) of the wildcard URL pattern grammar; they are combined into pUrl.
const (
// scheme: a lowercase letter followed by lowercase letters, digits, '+', '-' or '.'
pScheme = `(?P<scheme>[a-z][a-z0-9+\-.]*)`
// user info: any characters other than '@'
pUserInfo = `(?P<userinfo>[^@]*)`
// domain: '.'-separated labels of letters, digits, '_', '-' and the wildcards '*' and '?';
// the last label is limited to 1-11 characters
pDomain = `(?P<domain>([a-zA-Z0-9_\-*?]+\.)*([a-zA-Z0-9_\-*?]{1,11}))`
// port: 1-5 characters, each a digit or one of the wildcards '*' and '?'
pPort = `(?P<port>[0-9*?]{1,5})`
// path: optional leading '/', followed by any characters except '?', '#' and ':'
pPath = `(?P<path>\/?[^?#:]*)`
// params: raw query string, up to the fragment
pParams = `(?P<params>[^#]*)`
// fragment: everything that follows '#'
pFragment = `(?P<fragment>.*)`
)
var (
// pUrl is the complete wildcard URL pattern: an optional scheme (followed by one or two '/'), optional user info,
// an optional domain with optional port, a path (possibly empty), optional query params and an optional fragment.
// Warning: if pattern with custom scheme is provided, it's required to add "/" after ":".
// e.g. "custom-scheme:/some_path" is a valid pattern, but "custom-scheme:some_path" is not
pUrl = fmt.Sprintf(`^(%s:[/]{1,2})?(%s@)?(%s(:%s)?)?%s(\?%s)?(#%s)?`,
pScheme, pUserInfo, pDomain, pPort, pPath, pParams, pFragment)
// regexWildcardPattern is anchored only at the start (no trailing '$') and all of its parts are optional
regexWildcardPattern = regexp.MustCompile(pUrl)
// regexQueryParamsPattern matches a single "key[=value]" pair of a query string per match
regexQueryParamsPattern = regexp.MustCompile(`(?P<key>[^&=]+)(?P<eq>=?)(?P<value>[^&]+)?`)
)
/*****************************
Public
*****************************/
// wildcardUrlMatcher implements matcher.Matcher, matcher.ChainableMatcher and fmt.Stringer.
// It accepts an escaped URL string and matches it against the defined pattern, allowing the
// wildcards * and ? in the domain, port and path.
type wildcardUrlMatcher struct {
// raw is the original, unparsed pattern
raw string
// patterns holds the components parsed from raw
patterns
}
// patterns holds the components of a wildcard URL pattern, as filled in by parsePatterns.
type patterns struct {
scheme string
userInfo string
domain string
port string
path string
// params holds the query parameters parsed from the pattern (see parsePatterns)
params map[string][]string
// fragment is parsed but not consulted by urlMatches
fragment string
}
// NewWildcardUrlMatcher parses pattern and returns a matcher for it.
// The pattern is expected to be escaped for URL encoding. Any error from parsing is returned
// unchanged, with a nil matcher.
func NewWildcardUrlMatcher(pattern string) (*wildcardUrlMatcher, error) {
	m := &wildcardUrlMatcher{raw: pattern}
	if err := parsePatterns(pattern, &m.patterns); err != nil {
		return nil, err
	}
	return m, nil
}
// Matches reports whether i, which must be a URL string, matches the pattern.
// Any non-string input yields an error.
func (m *wildcardUrlMatcher) Matches(i interface{}) (bool, error) {
	urlStr, isString := i.(string)
	if !isString {
		return false, fmt.Errorf("unsupported URL with type [%T]", i)
	}
	return m.urlMatches(urlStr)
}
// MatchesWithContext behaves exactly like Matches; the context is ignored.
func (m *wildcardUrlMatcher) MatchesWithContext(_ context.Context, i interface{}) (bool, error) {
	matched, err := m.Matches(i)
	return matched, err
}
// Or chains this matcher with the given matchers using matcher.Or.
func (m *wildcardUrlMatcher) Or(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	chained := matcher.Or(m, matchers...)
	return chained
}

// And chains this matcher with the given matchers using matcher.And.
func (m *wildcardUrlMatcher) And(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	chained := matcher.And(m, matchers...)
	return chained
}
// String describes the matcher by its raw, unparsed pattern.
func (m *wildcardUrlMatcher) String() string {
	pattern := m.raw
	return fmt.Sprintf("matches pattern %s", pattern)
}
/*****************************
Helpers
*****************************/
// urlMatches parses urlStr and checks every URL component against the parsed pattern.
// A URL whose scheme has no known default port (anything other than http/https, per schemeToPort)
// is treated as a custom scheme and must equal the raw pattern exactly.
// Scheme, user info and query params are compared exactly; host, port and path allow wildcards.
func (m *wildcardUrlMatcher) urlMatches(urlStr string) (bool, error) {
	u, err := url.Parse(urlStr)
	if err != nil {
		return false, err
	}

	if u.Scheme != "" && schemeToPort(u.Scheme) == "" {
		return exactMatches(urlStr, m.raw, true), nil
	}

	// every check below is side-effect free, so short-circuiting does not change the result
	matched := exactMatches(u.Scheme, m.scheme, false) &&
		exactMatches(u.User, m.userInfo, false) &&
		queryParamsMatches(u.Query(), m.params) &&
		hostMatches(u.Hostname(), m.domain) &&
		portMatches(u.Scheme, u.Port(), m.scheme, m.port, m.domain) &&
		pathMatches(u.Path, m.path)
	return matched, nil
}
// parsePatterns splits a raw wildcard URL pattern into its components and stores them in dst.
//
// Query parameters of the pattern are collected into dst.params keyed by parameter name, with the
// value of each occurrence appended in order, e.g. "?a=1&a=2&flag" yields
// {"a": ["1", "2"], "flag": [""]}. Keys and values are whitespace-trimmed but not URL-unescaped.
// An error is returned only when regexWildcardPattern does not match raw at all.
func parsePatterns(raw string, dst *patterns) error {
	// parse overall pattern
	matches := regexWildcardPattern.FindStringSubmatch(raw)
	if matches == nil {
		return fmt.Errorf("invalid pattern %s", raw)
	}
	components := map[string]string{}
	for i, group := range regexWildcardPattern.SubexpNames() {
		if group == "" {
			// the whole match and unnamed helper groups are not pattern components
			continue
		}
		components[group] = strings.TrimSpace(matches[i])
	}
	dst.scheme = components["scheme"]
	dst.userInfo = components["userinfo"]
	dst.domain = components["domain"]
	dst.port = components["port"]
	dst.path = components["path"]
	dst.fragment = components["fragment"]

	// parse query params: each match is one "key[=value]" pair. Params must be keyed by the
	// parameter name (not by the regex group name) so queryParamsMatches can look them up.
	dst.params = map[string][]string{}
	keyIdx := regexQueryParamsPattern.SubexpIndex("key")
	valueIdx := regexQueryParamsPattern.SubexpIndex("value")
	for _, one := range regexQueryParamsPattern.FindAllStringSubmatch(components["params"], -1) {
		key := strings.TrimSpace(one[keyIdx])
		if key == "" {
			// a whitespace-only key carries no constraint
			continue
		}
		// an absent value group yields "", matching url.Values for "?key" / "?key="
		dst.params[key] = append(dst.params[key], strings.TrimSpace(one[valueIdx]))
	}
	return nil
}
// exactMatches reports whether the string form of value is exactly equal to pattern.
// value may be nil (treated as ""), a string or a *url.Userinfo; any other type never matches.
// An empty pattern matches every value when required is false, and only an empty value when
// required is true.
func exactMatches(value interface{}, pattern string, required bool) bool {
	var actual string
	switch v := value.(type) {
	case nil:
		// keep actual empty
	case string:
		actual = v
	case *url.Userinfo:
		actual = v.String()
	default:
		return false
	}

	if pattern == "" {
		return !required || actual == ""
	}
	return pattern == actual
}
// queryParamsMatches checks the actual query against the pattern's query params.
// For every pattern key that is present in query, the value lists must be equal according to
// sliceEquals. Pattern keys missing from query are skipped, and extra keys in query are allowed.
// A nil query behaves like an empty one.
func queryParamsMatches(query url.Values, pattern map[string][]string) (ret bool) {
	for key, want := range pattern {
		got, present := query[key]
		if present && !sliceEquals(got, want) {
			return false
		}
	}
	return true
}
// hostMatches reports whether the host value satisfies the domain pattern.
// An empty pattern matches any host. A pattern without wildcard characters matches the host
// itself and any of its subdomains; a pattern with wildcards is evaluated by wildcardMatches.
func hostMatches(value, pattern string) bool {
	if hasWildcard(pattern) {
		return wildcardMatches(value, pattern, true)
	}
	switch {
	case pattern == "", value == pattern:
		return true
	default:
		return strings.HasSuffix(value, "."+pattern)
	}
}
// portMatches reports whether the requested port satisfies the expected port pattern.
// Because the scheme is optional in patterns, an empty portPattern is resolved as follows:
//  1. neither schemePattern nor domainPattern is set (path-only pattern): any port matches;
//  2. only domainPattern is set: the default port of the requested scheme is expected, so a
//     request on a non-standard port does not match;
//  3. schemePattern is set: the default port of schemePattern is expected.
//
// An empty requested port is replaced with the default port of its scheme before matching,
// because an empty string would not match a '*' wildcard. portPattern itself may use wildcards.
func portMatches(scheme, port, schemePattern, portPattern, domainPattern string) bool {
	expected := portPattern
	if expected == "" {
		if schemePattern == "" && domainPattern == "" {
			return true
		}
		inferFrom := schemePattern
		if inferFrom == "" {
			inferFrom = scheme
		}
		expected = schemeToPort(inferFrom)
	}

	actual := port
	if actual == "" {
		actual = schemeToPort(scheme)
	}
	return wildcardMatches(actual, expected, true)
}
// pathMatches matches a URL path against a wildcard pattern. An empty path or an empty
// pattern is treated as the root path "/".
func pathMatches(value, pattern string) bool {
	orRoot := func(s string) string {
		if s == "" {
			return "/"
		}
		return s
	}
	return wildcardMatches(orRoot(value), orRoot(pattern), true)
}
// sliceEquals reports whether s1 and s2 are both non-nil, have the same length, and every
// element of s2 is contained in s1. Membership is checked via a set built from s1, so element
// multiplicity is not compared.
func sliceEquals(s1, s2 []string) bool {
	switch {
	case s1 == nil, s2 == nil, len(s1) != len(s2):
		return false
	}
	present := utils.NewStringSet(s1...)
	for _, want := range s2 {
		if !present.Has(want) {
			return false
		}
	}
	return true
}
// hasWildcard reports whether pattern contains any of the glob meta characters
// '*', '?', '\\', '[' or ']'.
func hasWildcard(pattern string) bool {
	return strings.IndexAny(pattern, `*?\[]`) >= 0
}
// wildcardMatches reports whether value matches the glob pattern, evaluated with doublestar.Match.
// An empty pattern matches only an empty value when required is true, and any value otherwise.
// If doublestar.Match returns an error (e.g. a malformed pattern), the result is false.
//
// Pattern syntax:
//
//	pattern:
//	  { term }
//	term:
//	  '*'         matches any sequence of non-path-separators
//	  '**'        matches any sequence of characters, including path separators
//	  '?'         matches any single non-path-separator character
//	  '[' [ '^' ] { character-range } ']'
//	              character class (must be non-empty)
//	  '{' { term } [ ',' { term } ... ] '}'
//	              alternatives
//	  c           matches character c (c != '*', '?', '\\', '[')
//	  '\\' c      matches character c
//
//	character-range:
//	  c           matches character c (c != '\\', '-', ']')
//	  '\\' c      matches character c
//	  lo '-' hi   matches character c for lo <= c <= hi
func wildcardMatches(value, pattern string, required bool) bool {
	if len(pattern) == 0 {
		return value == "" || !required
	}
	matched, err := doublestar.Match(pattern, value)
	if err != nil {
		return false
	}
	return matched
}
// schemeToPort returns the default port of the (case-sensitive) "http" and "https" schemes,
// or "" for any other scheme.
func schemeToPort(scheme string) string {
	if scheme == "http" {
		return "80"
	}
	if scheme == "https" {
		return "443"
	}
	return ""
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package common
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/tenancy"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// ResolveClientUserTenants computes the tenants a security context has access to from the client's
// assigned tenants and the user's designated tenants.
//
// For example, if a client is assigned to tenant-1, anyone using this client has access to tenant-1.
// If a user with access to tenant-1 and tenant-2 authenticates through this client, the resulting
// security context should only grant access to tenant-1. Because of this narrowing, the user's default
// tenant may no longer be valid, so the resolved default tenant is returned as well.
//
// Behavior:
//   - a == nil (client only): the client's assigned tenants are returned, and the default tenant is set
//     only when exactly one tenant is assigned.
//   - otherwise a must implement security.AccountTenancy, or an error is returned. The result is the union
//     of the account tenants that are descendants of any client tenant and the client tenants that are
//     descendants of any account tenant; security.SpecialTenantIdWildcard on one side admits all tenants
//     of the other side. The account's default tenant is returned only if the result contains the
//     wildcard or one of the resulting tenants has it as a descendant; otherwise defaultTenantId is "".
func ResolveClientUserTenants(ctx context.Context, a security.Account, c oauth2.OAuth2Client) (defaultTenantId string, assignedTenants []string, err error) {
	// client only
	if a == nil {
		assignedTenants = c.AssignedTenantIds().Values()
		if len(assignedTenants) == 1 {
			defaultTenantId = assignedTenants[0]
		}
		return defaultTenantId, assignedTenants, nil
	}

	at, ok := a.(security.AccountTenancy)
	if !ok {
		return "", nil, errors.New("account must have tenancy")
	}

	// Fetch both tenant collections once; the descendant checks below reuse them instead of
	// rebuilding the account tenant set on every loop iteration.
	clientTenants := c.AssignedTenantIds()
	accountTenantIds := at.DesignatedTenantIds()
	accountTenants := utils.NewStringSet(accountTenantIds...)

	tenantSet := utils.NewStringSet()
	// Account tenants that fall under (are descendants of) any of the client's tenants.
	if clientTenants.Has(security.SpecialTenantIdWildcard) {
		tenantSet = tenantSet.Add(accountTenantIds...)
	} else {
		for _, t := range accountTenantIds {
			if tenancy.AnyHasDescendant(ctx, clientTenants, t) {
				tenantSet = tenantSet.Add(t)
			}
		}
	}
	// Client tenants that fall under any of the account's tenants.
	if contains(accountTenantIds, security.SpecialTenantIdWildcard) {
		tenantSet = tenantSet.Add(clientTenants.Values()...)
	} else {
		for t := range clientTenants {
			if tenancy.AnyHasDescendant(ctx, accountTenants, t) {
				tenantSet = tenantSet.Add(t)
			}
		}
	}

	// Keep the account's default tenant only if it is still reachable through the resolved tenants.
	accountDefault := at.DefaultDesignatedTenantId()
	if tenantSet.Has(security.SpecialTenantIdWildcard) || tenancy.AnyHasDescendant(ctx, tenantSet, accountDefault) {
		defaultTenantId = accountDefault
	}
	return defaultTenantId, tenantSet.Values(), nil
}
// contains reports whether item occurs in slice, comparing elements with ==.
func contains[T comparable](slice []T, item T) bool {
	for _, v := range slice {
		if v == item {
			return true
		}
	}
	return false
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package common
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/common/internal"
"github.com/cisco-open/go-lanai/pkg/utils"
"strings"
"time"
)
// FactoryOptions is a functional option applied to a FactoryOption by NewContextDetailsFactory.
type FactoryOptions func(option *FactoryOption)
// FactoryOption holds ContextDetailsFactory settings; it currently has no fields.
type FactoryOption struct {
}
// ContextDetailsFactory creates security.ContextDetails for OAuth2 requests from values found in the context.
type ContextDetailsFactory struct {
}
// NewContextDetailsFactory returns a new ContextDetailsFactory. Each option is applied to a local
// FactoryOption, which is not retained by the returned factory.
func NewContextDetailsFactory(opts ...FactoryOptions) *ContextDetailsFactory {
	var option FactoryOption
	for i := range opts {
		opts[i](&option)
	}
	return &ContextDetailsFactory{}
}
// facts gathers the inputs used to build context details; all fields except request are read from the context by loadFacts.
type facts struct {
// request is the OAuth2 request the details are created for
request oauth2.OAuth2Request
// client comes from oauth2.CtxKeyAuthenticatedClient
client oauth2.OAuth2Client
// account comes from oauth2.CtxKeyAuthenticatedAccount; nil means client-only authentication
account security.Account
// tenant comes from oauth2.CtxKeyAuthorizedTenant; non-nil selects the tenanted details variants
tenant *security.Tenant
// provider comes from oauth2.CtxKeyAuthorizedProvider
provider *security.Provider
// userAuth comes from oauth2.CtxKeyUserAuthentication; used for the session ID in KV details
userAuth oauth2.UserAuthentication
// issueTime defaults to time.Now() when not in the context
issueTime time.Time
// expriyTime (sic) is the authorization expiry time; zero value when not in the context
expriyTime time.Time
// authTime defaults to issueTime when not in the context
authTime time.Time
// source comes from oauth2.CtxKeySourceAuthentication; used for proxy and key-value details
source oauth2.Authentication
}
// New creates the security.ContextDetails for request from facts loaded out of ctx.
// Without an authenticated account, client-only details are produced; otherwise details that
// carry both the client and the user. Either kind has a tenanted variant when a tenant is authorized.
func (f *ContextDetailsFactory) New(ctx context.Context, request oauth2.OAuth2Request) (security.ContextDetails, error) {
	loaded := f.loadFacts(ctx, request)
	if loaded.account != nil {
		return f.create(ctx, loaded)
	}
	return f.createSimple(ctx, loaded)
}
/*
*********************
Helpers
*********************
*/
// loadFacts reads the inputs for building context details from ctx.
// The authenticated client is mandatory: the unchecked type assertion panics if it is absent.
// All other values are optional, but when present they must have the expected type (unchecked
// assertions). When absent, issueTime defaults to time.Now() and authTime to issueTime.
func (f *ContextDetailsFactory) loadFacts(ctx context.Context, request oauth2.OAuth2Request) *facts {
	loaded := &facts{
		request: request,
		client:  ctx.Value(oauth2.CtxKeyAuthenticatedClient).(oauth2.OAuth2Client),
	}
	if v := ctx.Value(oauth2.CtxKeyAuthenticatedAccount); v != nil {
		loaded.account = v.(security.Account)
	}
	if v := ctx.Value(oauth2.CtxKeyAuthorizedTenant); v != nil {
		loaded.tenant = v.(*security.Tenant)
	}
	if v := ctx.Value(oauth2.CtxKeyAuthorizedProvider); v != nil {
		loaded.provider = v.(*security.Provider)
	}
	if v := ctx.Value(oauth2.CtxKeyUserAuthentication); v != nil {
		loaded.userAuth = v.(oauth2.UserAuthentication)
	}
	if v := ctx.Value(oauth2.CtxKeyAuthorizationIssueTime); v != nil {
		loaded.issueTime = v.(time.Time)
	} else {
		loaded.issueTime = time.Now()
	}
	if v := ctx.Value(oauth2.CtxKeyAuthorizationExpiryTime); v != nil {
		loaded.expriyTime = v.(time.Time)
	}
	if v := ctx.Value(oauth2.CtxKeyAuthenticationTime); v != nil {
		loaded.authTime = v.(time.Time)
	} else {
		loaded.authTime = loaded.issueTime
	}
	if v := ctx.Value(oauth2.CtxKeySourceAuthentication); v != nil {
		loaded.source = v.(oauth2.Authentication)
	}
	return loaded
}
// create builds context details for an authentication with both a client and a user account.
// It returns *internal.ClientUserTenantedContextDetails when an authorized tenant is present and
// *internal.ClientUserContextDetails otherwise. Effective tenant access is computed by
// ResolveClientUserTenants from the client's and the account's tenant assignments.
// The account must implement security.AccountTenancy (it is type-asserted without a check), and
// facts.client is expected to be non-nil because ResolveClientUserTenants dereferences it.
func (f *ContextDetailsFactory) create(ctx context.Context, facts *facts) (security.ContextDetails, error) {
	// user
	ud := internal.UserDetails{
		Id:                facts.account.ID().(string),
		Username:          facts.account.Username(),
		AccountType:       facts.account.Type(),
		AssignedTenantIds: utils.NewStringSet(facts.account.(security.AccountTenancy).DesignatedTenantIds()...),
	}
	if meta, ok := facts.account.(security.AccountMetadata); ok {
		ud.FirstName = meta.FirstName()
		ud.LastName = meta.LastName()
		ud.Email = meta.Email()
		ud.LocaleCode = meta.LocaleCode()
		ud.CurrencyCode = meta.CurrencyCode()
	}

	cd := f.createClientDetails(facts.client)

	// auth details
	ad, e := f.createAuthDetails(ctx, facts)
	if e != nil {
		return nil, e
	}

	_, effectiveTenantIds, e := ResolveClientUserTenants(ctx, facts.account, facts.client)
	if e != nil {
		return nil, e
	}

	details := internal.ClientUserContextDetails{
		User:           ud,
		Client:         cd,
		Authentication: *ad,
		KV:             f.createKVDetails(ctx, facts),
		TenantAccess: internal.TenantAccessDetails{
			EffectiveAssignedTenantIds: utils.NewStringSet(effectiveTenantIds...),
		},
	}
	if facts.tenant == nil {
		return &details, nil
	}
	return &internal.ClientUserTenantedContextDetails{
		ClientUserContextDetails: details,
		Provider:                 f.createProviderDetails(facts.provider),
		Tenant:                   f.createTenantDetails(facts.tenant),
	}, nil
}

// createSimple builds context details for a client-only authentication (no user account).
// It returns *internal.ClientTenantedContextDetails when an authorized tenant is present and
// *internal.ClientContextDetails otherwise.
func (f *ContextDetailsFactory) createSimple(ctx context.Context, facts *facts) (security.ContextDetails, error) {
	ad, e := f.createAuthDetails(ctx, facts)
	if e != nil {
		return nil, e
	}

	details := internal.ClientContextDetails{
		Authentication: *ad,
		KV:             f.createKVDetails(ctx, facts),
		Client:         f.createClientDetails(facts.client),
	}
	if facts.tenant == nil {
		return &details, nil
	}
	return &internal.ClientTenantedContextDetails{
		ClientContextDetails: details,
		Tenant:               f.createTenantDetails(facts.tenant),
		Provider:             f.createProviderDetails(facts.provider),
	}, nil
}

// createClientDetails copies the client's ID, scopes and assigned tenant IDs; a nil client yields zero-value details.
func (f *ContextDetailsFactory) createClientDetails(client oauth2.OAuth2Client) internal.ClientDetails {
	if client == nil {
		return internal.ClientDetails{}
	}
	return internal.ClientDetails{
		Id:                client.ClientId(),
		Scopes:            client.Scopes(),
		AssignedTenantIds: client.AssignedTenantIds(),
	}
}

// createTenantDetails copies the tenant's ID, external ID and suspension flag; a nil tenant yields zero-value details.
func (f *ContextDetailsFactory) createTenantDetails(tenant *security.Tenant) internal.TenantDetails {
	if tenant == nil {
		return internal.TenantDetails{}
	}
	return internal.TenantDetails{
		Id:         tenant.Id,
		ExternalId: tenant.ExternalId,
		Suspended:  tenant.Suspended,
	}
}

// createProviderDetails copies the provider's descriptive fields. loadFacts treats the provider as
// optional, so a nil provider yields zero-value details instead of a nil-pointer panic.
func (f *ContextDetailsFactory) createProviderDetails(provider *security.Provider) internal.ProviderDetails {
	if provider == nil {
		return internal.ProviderDetails{}
	}
	return internal.ProviderDetails{
		Id:               provider.Id,
		Name:             provider.Name,
		DisplayName:      provider.DisplayName,
		Description:      provider.Description,
		Email:            provider.Email,
		NotificationType: provider.NotificationType,
	}
}
// createAuthDetails assembles authentication details: roles, permissions, the authentication,
// issue and expiry times, and proxy information (via populateProxyDetails).
// With an account, permissions come from the account, and roles come from security.AccountMetadata
// (left unset when the account has no metadata). Without an account, roles are an empty set and the
// request scopes are used as permissions. The returned error is always nil.
func (f *ContextDetailsFactory) createAuthDetails(ctx context.Context, facts *facts) (*internal.AuthenticationDetails, error) {
	d := &internal.AuthenticationDetails{
		AuthenticationTime: facts.authTime,
		IssueTime:          facts.issueTime,
		ExpiryTime:         facts.expriyTime,
	}
	if facts.account == nil {
		d.Roles = utils.NewStringSet()
		d.Permissions = facts.request.Scopes()
	} else {
		d.Permissions = utils.NewStringSet(facts.account.Permissions()...)
		if meta, ok := facts.account.(security.AccountMetadata); ok {
			d.Roles = utils.NewStringSet(meta.RoleNames()...)
		}
	}
	f.populateProxyDetails(ctx, d, facts)
	return d, nil
}
// populateProxyDetails marks d as proxied when the source authentication indicates it.
// Details already flagged as proxied (security.ProxiedUserDetails) keep their original username.
// Otherwise, source security.UserDetails whose trimmed username differs from the current account's
// (or when there is no account) mark d as proxied, recording the source username as the original one.
func (f *ContextDetailsFactory) populateProxyDetails(_ context.Context, d *internal.AuthenticationDetails, facts *facts) {
	if facts.source == nil {
		return
	}
	srcDetails := facts.source.Details()

	if proxied, ok := srcDetails.(security.ProxiedUserDetails); ok && proxied.Proxied() {
		d.Proxied = true
		d.OriginalUsername = proxied.OriginalUsername()
		return
	}

	user, ok := srcDetails.(security.UserDetails)
	if !ok {
		return
	}
	originalUsername := strings.TrimSpace(user.Username())
	if facts.account != nil && strings.TrimSpace(facts.account.Username()) == originalUsername {
		return
	}
	d.Proxied = true
	d.OriginalUsername = originalUsername
}
// createKVDetails builds the key-value details map: the session ID from the user authentication
// (if present), the request extensions (if there is a request), and finally every value from the
// source authentication's security.KeyValueDetails, which overrides earlier entries with the same key.
func (f *ContextDetailsFactory) createKVDetails(_ context.Context, facts *facts) (ret map[string]interface{}) {
	ret = make(map[string]interface{})
	if facts.userAuth != nil {
		if sid, ok := facts.userAuth.DetailsMap()[security.DetailsKeySessionId]; ok {
			ret[security.DetailsKeySessionId] = sid
		}
	}
	if facts.request != nil {
		ret[oauth2.DetailsKeyRequestExt] = facts.request.Extensions()
		// request parameters are intentionally not copied (oauth2.DetailsKeyRequestParams)
	}
	if facts.source != nil {
		if kv, ok := facts.source.Details().(security.KeyValueDetails); ok {
			for k, v := range kv.Values() {
				ret[k] = v
			}
		}
	}
	return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package common
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/common/internal"
"time"
)
// Redis DB index and key prefixes of the relationship records kept by RedisContextDetailsStore.
const (
// redisDB is the Redis DB index selected by NewRedisContextDetailsStore
redisDB = 13
// AccessToken -> ContextDetails
prefixAccessTokenToDetails = "AAT"
// RefreshToken -> Authentication
prefixRefreshTokenToAuthentication = "ART"
// AccessToken <- User & Client
prefixAccessTokenFromUserAndClient = "AUC"
// RefreshToken <- User & Client
prefixRefreshTokenFromUserAndClient = "RUC"
// AccessToken <-> RefreshToken
prefixAccessFromRefreshToken = "AR"
// RefreshToken -> SessionId
prefixRefreshTokenFromSessionId = "RS"
// AccessToken -> SessionId
prefixAccessTokenFromSessionId = "AS"
/*
Original comment from Java implementation:
When specific token of a client is used, we look up the session and update
its last requested time
These records should have an expiry time equal to the token's expiry time
*/
// those relationships are not needed anymore, because session ID is stored in UserAuthentication and context details
// which are mapped from refresh and access token
//prefixRefreshTokenToSessionId = "R_TO_S"
//prefixAccessTokenToSessionId = "A_TO_S"
/*
* Original comment from Java implementation:
* We also want to store the original OAuth2 Request, because JWT token doesn't carry all information
* from OAuth2 request (we don't want super long JWT). We don't want to carry it in SecurityContextDetails
* because original OAuth2 request is only needed by authorization server
*/
// Those relationships are not needed anymore, because additional details such as session ID are now carried in
// security.KeyValueDetails
//prefixAccessTokenToRequest = "ORAT"
//prefixRefreshTokenToRequest = "ORRT"
//prefix = ""
)
// fmt templates for errors about unsupported inputs; %T renders the offending value's type.
const (
errTmplUnsupportedKey = `unsupported key type %T`
errTmplUnsupportedDetails = `unsupported details type %T`
)
// RedisContextDetailsStore implements security.ContextDetailsStore and auth.AuthorizationRegistry
// on top of a Redis client.
type RedisContextDetailsStore struct {
// vTag is passed to key functions when building Redis keys; the constructor sets it to security.CompatibilityReference
vTag string
// client is the Redis client bound to DB index redisDB
client redis.Client
// timeoutApplier, when non-nil, is consulted by ContextDetailsExists for tokens linked to a session
timeoutApplier oauth2.TimeoutApplier
}
// NewRedisContextDetailsStore creates a RedisContextDetailsStore whose Redis client uses DB index
// redisDB. It panics if the client factory fails to create the client.
func NewRedisContextDetailsStore(ctx context.Context, cf redis.ClientFactory, timeoutApplier oauth2.TimeoutApplier) *RedisContextDetailsStore {
	selectDB := func(opt *redis.ClientOption) {
		opt.DbIndex = redisDB
	}
	client, err := cf.New(ctx, selectDB)
	if err != nil {
		panic(err)
	}
	return &RedisContextDetailsStore{
		vTag:           security.CompatibilityReference,
		client:         client,
		timeoutApplier: timeoutApplier,
	}
}
/**********************************
security.ContextDetailsStore
**********************************/
// ReadContextDetails loads the context details stored for an oauth2.AccessToken key.
// Any other key type yields an "unsupported key type" error.
func (r *RedisContextDetailsStore) ReadContextDetails(c context.Context, key interface{}) (security.ContextDetails, error) {
	token, ok := key.(oauth2.AccessToken)
	if !ok {
		return nil, fmt.Errorf(errTmplUnsupportedKey, key)
	}
	return r.loadDetailsFromAccessToken(c, token)
}
// SaveContextDetails stores details under an oauth2.AccessToken key.
// Only the four concrete details types produced by this package's factory are accepted; the
// details type is validated before the key type.
func (r *RedisContextDetailsStore) SaveContextDetails(c context.Context, key interface{}, details security.ContextDetails) error {
	switch details.(type) {
	case *internal.ClientUserContextDetails, *internal.ClientContextDetails,
		*internal.ClientUserTenantedContextDetails, *internal.ClientTenantedContextDetails:
		// supported
	default:
		return fmt.Errorf(errTmplUnsupportedDetails, details)
	}

	token, ok := key.(oauth2.AccessToken)
	if !ok {
		return fmt.Errorf(errTmplUnsupportedKey, key)
	}
	return r.saveAccessTokenToDetails(c, token, details)
}
// RemoveContextDetails deletes the details stored for an oauth2.AccessToken key.
// Any other key type yields an "unsupported key type" error.
func (r *RedisContextDetailsStore) RemoveContextDetails(c context.Context, key interface{}) error {
	token, ok := key.(oauth2.AccessToken)
	if !ok {
		return fmt.Errorf(errTmplUnsupportedKey, key)
	}
	// only the error of the removal matters here
	_, err := r.doRemoveDetials(c, token, "")
	return err
}
// ContextDetailsExists reports whether details are still available for an oauth2.AccessToken key.
// If the token is linked to a session and a timeout applier is configured, the result of
// timeoutApplier.ApplyTimeout for that session is returned (its error is ignored); otherwise the
// existence of the token's details record is checked. Other key types always yield false.
func (r *RedisContextDetailsStore) ContextDetailsExists(c context.Context, key interface{}) bool {
	token, ok := key.(oauth2.AccessToken)
	if !ok {
		return false
	}

	sessionId, err := r.FindSessionId(c, token)
	if err != nil || sessionId == "" || r.timeoutApplier == nil {
		return r.exists(c, keyFuncAccessTokenToDetails(uniqueTokenKey(token)))
	}
	valid, _ := r.timeoutApplier.ApplyTimeout(c, sessionId)
	return valid
}
/**********************************
auth.AuthorizationRegistry
**********************************/
// RegisterRefreshToken records these relationships for the refresh token:
//   - RefreshToken -> Authentication "ART"
//   - RefreshToken <- User & Client "RUC"
//   - RefreshToken -> SessionId "RS", only when the request extension oauth2.ExtUseSessionTimeout is true
//
// The first error encountered is returned.
func (r *RedisContextDetailsStore) RegisterRefreshToken(c context.Context, token oauth2.RefreshToken, oauth oauth2.Authentication) error {
	if err := r.saveRefreshTokenToAuth(c, token, oauth); err != nil {
		return err
	}
	if err := r.saveRefreshTokenFromUserClient(c, token, oauth); err != nil {
		return err
	}

	// indexing a nil extensions map yields nil, which fails the bool assertion
	useSession, _ := oauth.OAuth2Request().Extensions()[oauth2.ExtUseSessionTimeout].(bool)
	if !useSession {
		return nil
	}
	return r.saveRefreshTokenToSession(c, token, oauth)
}
// RegisterAccessToken records these relationships for the access token:
//   - AccessToken <- User & Client "AUC"
//   - AccessToken -> SessionId "AS", only when the request extension oauth2.ExtUseSessionTimeout is true
//   - RefreshToken <-> AccessToken "AR"
//
// The first error encountered is returned.
func (r *RedisContextDetailsStore) RegisterAccessToken(ctx context.Context, token oauth2.AccessToken, oauth oauth2.Authentication) error {
	if err := r.saveAccessTokenFromUserClient(ctx, token, oauth); err != nil {
		return err
	}

	// indexing a nil extensions map yields nil, which fails the bool assertion
	if useSession, _ := oauth.OAuth2Request().Extensions()[oauth2.ExtUseSessionTimeout].(bool); useSession {
		if err := r.saveAccessTokenToSession(ctx, token, oauth); err != nil {
			return err
		}
	}
	return r.saveAccessRefreshTokenRelation(ctx, token)
}
// ReadStoredAuthorization returns the authentication recorded for the refresh token
// (the RefreshToken -> Authentication relationship).
func (r *RedisContextDetailsStore) ReadStoredAuthorization(c context.Context, token oauth2.RefreshToken) (oauth2.Authentication, error) {
	auth, err := r.loadAuthFromRefreshToken(c, token)
	return auth, err
}
// FindSessionId looks up the session ID linked to an access or refresh token, matching any
// session in the token-to-session record key. Other token types yield an "unsupported key type" error.
func (r *RedisContextDetailsStore) FindSessionId(ctx context.Context, token oauth2.Token) (string, error) {
	if t, ok := token.(oauth2.AccessToken); ok {
		return r.loadSessionId(ctx, keyFuncAccessTokenFromSession(uniqueTokenKey(t), "*"))
	}
	if t, ok := token.(oauth2.RefreshToken); ok {
		return r.loadSessionId(ctx, keyFuncRefreshTokenFromSession(uniqueTokenKey(t), "*"))
	}
	return "", fmt.Errorf(errTmplUnsupportedKey, token)
}
// RevokeRefreshToken removes these Redis records:
//   - RefreshToken -> Authentication "ART"
//   - RefreshToken <- User & Client "RUC"
//   - RefreshToken -> SessionId "RS"
//   - all access tokens of the refresh token (each removal also drops its AccessToken <-> RefreshToken "AR" record)
func (r *RedisContextDetailsStore) RevokeRefreshToken(ctx context.Context, token oauth2.RefreshToken) error {
	err := r.doRemoveRefreshToken(ctx, token, "")
	return err
}
// RevokeAccessToken removes these Redis records:
//   - AccessToken -> ContextDetails "AAT"
//   - AccessToken <- User & Client "AUC"
//   - AccessToken -> SessionId "AS"
//   - AccessToken <-> RefreshToken "AR"
func (r *RedisContextDetailsStore) RevokeAccessToken(ctx context.Context, token oauth2.AccessToken) error {
	err := r.doRemoveAccessToken(ctx, token, "")
	return err
}
// RevokeAllAccessTokens removes every access token associated with the given refresh token,
// located through the AccessToken <-> RefreshToken "AR" records.
func (r *RedisContextDetailsStore) RevokeAllAccessTokens(ctx context.Context, token oauth2.RefreshToken) error {
	accessKeys := keyFuncAccessFromRefresh("*", uniqueTokenKey(token))
	_, err := r.doRemoveAllAccessTokens(ctx, accessKeys)
	return err
}
// RevokeUserAccess removes all access tokens issued to the given user and, when revokeRefreshToken
// is true, all of the user's refresh tokens first. Tokens are located through the
// AccessToken <- User & Client "AUC" and RefreshToken <- User & Client "RUC" records.
func (r *RedisContextDetailsStore) RevokeUserAccess(ctx context.Context, username string, revokeRefreshToken bool) error {
	if revokeRefreshToken {
		refreshKeys := keyFuncRefreshTokenFromUserAndClient("*", username, "*")
		if _, err := r.doRemoveAllRefreshTokens(ctx, refreshKeys); err != nil {
			return err
		}
	}
	accessKeys := keyFuncAccessTokenFromUserAndClient("*", username, "*")
	_, err := r.doRemoveAllAccessTokens(ctx, accessKeys)
	return err
}
// RevokeClientAccess deletes every access token issued to the given client, across all users. When
// revokeRefreshToken is true, the client's refresh tokens are deleted first. Tokens are located
// through the AccessToken <- User & Client "AUC" and RefreshToken <- User & Client "RUC" records.
func (r *RedisContextDetailsStore) RevokeClientAccess(ctx context.Context, clientId string, revokeRefreshToken bool) error {
	accessPattern := keyFuncAccessTokenFromUserAndClient("*", "*", clientId)
	if revokeRefreshToken {
		refreshPattern := keyFuncRefreshTokenFromUserAndClient("*", "*", clientId)
		if _, err := r.doRemoveAllRefreshTokens(ctx, refreshPattern); err != nil {
			return err
		}
	}
	_, err := r.doRemoveAllAccessTokens(ctx, accessPattern)
	return err
}
// RevokeSessionAccess deletes every access token issued under the given session. When
// revokeRefreshToken is true, the session's refresh tokens are deleted first. Tokens are located
// through the AccessToken -> SessionId "AS" and RefreshToken -> SessionId "RS" records.
func (r *RedisContextDetailsStore) RevokeSessionAccess(ctx context.Context, sessionId string, revokeRefreshToken bool) error {
	accessPattern := keyFuncAccessTokenFromSession("*", sessionId)
	if revokeRefreshToken {
		refreshPattern := keyFuncRefreshTokenFromSession("*", sessionId)
		if _, err := r.doRemoveAllRefreshTokens(ctx, refreshPattern); err != nil {
			return err
		}
	}
	_, err := r.doRemoveAllAccessTokens(ctx, accessPattern)
	return err
}
/*
********************
Common Helpers
********************
*/
// doSave JSON-encodes value and SETs it under the key built by keyFunc with r.vTag.
// If expiry lies in the future, the TTL is the time remaining until expiry. Otherwise
// redis.KeepTTL is passed as the expiration.
// NOTE(review): an expiry that is already in the past also falls back to KeepTTL, so such a
// record may not expire at all — confirm this is intended.
func (r *RedisContextDetailsStore) doSave(c context.Context, keyFunc keyFunc, value interface{}, expiry time.Time) error {
	payload, err := json.Marshal(value)
	if err != nil {
		return err
	}
	key := keyFunc(r.vTag)
	var ttl time.Duration
	if now := time.Now(); expiry.After(now) {
		ttl = expiry.Sub(now)
	} else {
		ttl = time.Duration(redis.KeepTTL)
	}
	return r.client.Set(c, key, payload, ttl).Err()
}
// doLoad GETs the key built by keyFunc with r.vTag and JSON-decodes the stored string into value.
// A GET error (including a missing key) is returned as-is.
func (r *RedisContextDetailsStore) doLoad(c context.Context, keyFunc keyFunc, value interface{}) error {
	res := r.client.Get(c, keyFunc(r.vTag))
	if err := res.Err(); err != nil {
		return err
	}
	return json.Unmarshal([]byte(res.Val()), value)
}
// doDelete DELs the single key built by keyFunc with r.vTag. It returns the number of keys
// removed as reported by Redis, together with the command error.
func (r *RedisContextDetailsStore) doDelete(c context.Context, keyFunc keyFunc) (int, error) {
	res := r.client.Del(c, keyFunc(r.vTag))
	removed, err := res.Val(), res.Err()
	return int(removed), err
}
// doDeleteWithKeys DELs the given literal keys and returns how many were removed.
// An empty list returns (0, nil) without contacting Redis, because DEL needs at least one key.
func (r *RedisContextDetailsStore) doDeleteWithKeys(c context.Context, keys []string) (int, error) {
	if len(keys) > 0 {
		res := r.client.Del(c, keys...)
		return int(res.Val()), res.Err()
	}
	return 0, nil
}
// doDeleteWithWildcard lists every key matching the pattern built by keyFunc and deletes them.
// It returns the number removed, or (0, nil) when nothing matches.
func (r *RedisContextDetailsStore) doDeleteWithWildcard(c context.Context, keyFunc keyFunc) (int, error) {
	keys, err := r.doList(c, keyFunc)
	if err != nil {
		return 0, err
	}
	// doDeleteWithKeys already short-circuits on an empty key list
	return r.doDeleteWithKeys(c, keys)
}
// doList returns every key matching the (possibly wildcard) pattern built by keyFunc with r.vTag.
// NOTE(review): this issues the Redis KEYS command, which walks the whole keyspace and blocks the
// server while it runs; SCAN may be preferable for large datasets.
func (r *RedisContextDetailsStore) doList(c context.Context, keyFunc keyFunc) ([]string, error) {
	res := r.client.Keys(c, keyFunc(r.vTag))
	if err := res.Err(); err != nil {
		return nil, err
	}
	return res.Val(), nil
}
// exists reports whether the key built by keyFunc is present. Any Redis error is treated as "absent".
func (r *RedisContextDetailsStore) exists(c context.Context, keyFunc keyFunc) bool {
	res := r.client.Exists(c, keyFunc(r.vTag))
	if res.Err() != nil {
		return false
	}
	return res.Val() != 0
}
// doMultiDelete runs every action in order and keeps going after failures. The per-action counts
// are discarded. If any action failed, the error from the last failing action is returned.
func (r *RedisContextDetailsStore) doMultiDelete(actions ...func() (int, error)) error {
	var lastErr error
	for i := range actions {
		if _, err := actions[i](); err != nil {
			lastErr = err
		}
	}
	return lastErr
}
/*
********************
Access Token
********************
*/
// saveAccessTokenToDetails stores the AccessToken -> ContextDetails "AAT" record.
// The record expires together with the access token.
func (r *RedisContextDetailsStore) saveAccessTokenToDetails(c context.Context, t oauth2.AccessToken, details security.ContextDetails) error {
	key := keyFuncAccessTokenToDetails(uniqueTokenKey(t))
	return r.doSave(c, key, details, t.ExpiryTime())
}
// saveAccessTokenFromUserClient stores the AccessToken <- User & Client "AUC" relation record.
// The record expires together with the access token.
func (r *RedisContextDetailsStore) saveAccessTokenFromUserClient(c context.Context, t oauth2.AccessToken, oauth oauth2.Authentication) error {
	clientId := oauth.OAuth2Request().ClientId()
	// the GetUsername error is deliberately ignored; whatever username it yields is recorded
	username, _ := security.GetUsername(oauth.UserAuthentication())
	atk := uniqueTokenKey(t)
	relation := &internal.RelationTokenUserClient{
		RelationToken: internal.RelationToken{TokenKey: atk},
		Username:      username,
		ClientId:      clientId,
	}
	return r.doSave(c, keyFuncAccessTokenFromUserAndClient(atk, username, clientId), relation, t.ExpiryTime())
}
// saveAccessTokenToSession stores the AccessToken -> SessionId "AS" relation record.
// It does nothing when the authentication carries no session ID.
func (r *RedisContextDetailsStore) saveAccessTokenToSession(c context.Context, t oauth2.AccessToken, oauth oauth2.Authentication) error {
	sid := r.findSessionId(c, oauth)
	if len(sid) == 0 {
		return nil
	}
	atk := uniqueTokenKey(t)
	relation := &internal.RelationTokenSession{
		RelationToken: internal.RelationToken{TokenKey: atk},
		SessionId:     sid,
	}
	return r.doSave(c, keyFuncAccessTokenFromSession(atk, sid), relation, t.ExpiryTime())
}
// saveAccessRefreshTokenRelation stores the AccessToken <-> RefreshToken "AR" relation record.
// It does nothing when the access token has no refresh token. The record expires with the access token.
func (r *RedisContextDetailsStore) saveAccessRefreshTokenRelation(c context.Context, t oauth2.AccessToken) error {
	if t.RefreshToken() == nil {
		return nil
	}
	atk, rtk := uniqueTokenKey(t), uniqueTokenKey(t.RefreshToken())
	relation := &internal.RelationAccessRefresh{
		RelationToken:   internal.RelationToken{TokenKey: atk},
		RefreshTokenKey: rtk,
	}
	return r.doSave(c, keyFuncAccessFromRefresh(atk, rtk), relation, t.ExpiryTime())
}
// loadDetailsFromAccessToken reads the AccessToken -> ContextDetails "AAT" record for t. It
// returns the narrowest concrete details type matching what was loaded:
//   - no user, no tenant   -> *internal.ClientContextDetails
//   - no user, with tenant -> *internal.ClientTenantedContextDetails
//   - user, no tenant      -> *internal.ClientUserContextDetails
//   - user and tenant      -> *internal.ClientUserTenantedContextDetails
//
// A user counts as present only when both its Id and Username are non-empty. If the token is linked
// to a session and a timeoutApplier is configured, the session timeout is applied first. An error is
// returned when the session is no longer valid.
func (r *RedisContextDetailsStore) loadDetailsFromAccessToken(c context.Context, t oauth2.AccessToken) (security.ContextDetails, error) {
	// a failed session lookup just means there is no session to validate against
	sId, err := r.FindSessionId(c, t)
	if err == nil && sId != "" && r.timeoutApplier != nil {
		// NOTE(review): the error from ApplyTimeout is ignored; only the validity flag is used.
		valid, _ := r.timeoutApplier.ApplyTimeout(c, sId)
		if !valid {
			return nil, errors.New("token is invalid because it's expired by its associate session")
		}
	}

	fullDetails := internal.NewClientUserTenantedContextDetails()
	// Decode into the pointer itself, not a pointer-to-pointer. A stored JSON "null" then leaves the
	// pre-initialized struct intact instead of replacing fullDetails with nil and panicking below.
	if e := r.doLoad(c, keyFuncAccessTokenToDetails(uniqueTokenKey(t)), fullDetails); e != nil {
		return nil, e
	}

	hasUser := fullDetails.User.Id != "" && fullDetails.User.Username != ""
	hasTenant := fullDetails.Tenant.Id != ""
	switch {
	case !hasUser && !hasTenant:
		// no user details, we assume it's a simple client context
		return &internal.ClientContextDetails{
			Authentication: fullDetails.Authentication,
			KV:             fullDetails.KV,
			Client:         fullDetails.Client,
		}, nil
	case !hasUser:
		return &internal.ClientTenantedContextDetails{
			ClientContextDetails: internal.ClientContextDetails{
				Authentication: fullDetails.Authentication,
				KV:             fullDetails.KV,
				Client:         fullDetails.Client,
			},
			Tenant: fullDetails.Tenant,
		}, nil
	case !hasTenant:
		return &internal.ClientUserContextDetails{
			User:           fullDetails.User,
			Client:         fullDetails.Client,
			TenantAccess:   fullDetails.TenantAccess,
			Authentication: fullDetails.Authentication,
			KV:             fullDetails.KV,
		}, nil
	default:
		return fullDetails, nil
	}
}
// loadSessionId lists the session relation keys matching keyfunc and decodes the first one found.
// It returns that record's SessionId, or an error when no key matches.
func (r *RedisContextDetailsStore) loadSessionId(ctx context.Context, keyfunc keyFunc) (string, error) {
	keys, err := r.doList(ctx, keyfunc)
	switch {
	case err != nil:
		return "", err
	case len(keys) == 0:
		return "", fmt.Errorf("session ID not found for token")
	}
	// only the first matching relation record is consulted
	var relation internal.RelationTokenSession
	if err := r.doLoad(ctx, keyFuncLiteral(keys[0]), &relation); err != nil {
		return "", err
	}
	return relation.SessionId, nil
}
// doRemoveDetials deletes the AccessToken -> ContextDetails "AAT" record. The record is identified by
// token when non-nil, otherwise by the precomputed unique key atk. It returns the number of keys removed.
func (r *RedisContextDetailsStore) doRemoveDetials(ctx context.Context, token oauth2.AccessToken, atk string) (int, error) {
	key := atk
	if token != nil {
		key = uniqueTokenKey(token)
	}
	removed, err := r.doDelete(ctx, keyFuncAccessTokenToDetails(key))
	if err != nil {
		return 0, err
	}
	return removed, nil
}
// doRemoveAccessToken deletes every record belonging to one access token. The token is identified
// by token when non-nil, otherwise by the precomputed unique key atk. Removed records:
//   - AccessToken -> ContextDetails "AAT"
//   - AccessToken <- User & Client "AUC"
//   - AccessToken -> SessionId "AS"
//   - AccessToken <-> RefreshToken "AR"
//
// Every deletion is attempted; the last error encountered (if any) is returned.
func (r *RedisContextDetailsStore) doRemoveAccessToken(ctx context.Context, token oauth2.AccessToken, atk string) error {
	if token != nil {
		atk = uniqueTokenKey(token)
	}
	removeDetails := func() (int, error) { return r.doRemoveDetials(ctx, token, atk) }
	removeUserClient := func() (int, error) {
		return r.doDeleteWithWildcard(ctx, keyFuncAccessTokenFromUserAndClient(atk, "*", "*"))
	}
	removeSession := func() (int, error) {
		return r.doDeleteWithWildcard(ctx, keyFuncAccessTokenFromSession(atk, "*"))
	}
	removeRefreshRelation := func() (int, error) {
		return r.doDeleteWithWildcard(ctx, keyFuncAccessFromRefresh(atk, "*"))
	}
	return r.doMultiDelete(removeDetails, removeUserClient, removeSession, removeRefreshRelation)
}
// doRemoveAllAccessTokens finds every relation record matching keyfunc and removes the access token
// each one points to. It then deletes the matched relation records themselves.
// count is the number of matched relation keys. err is the last error seen during removal; a
// failed listing is returned immediately.
func (r *RedisContextDetailsStore) doRemoveAllAccessTokens(ctx context.Context, keyfunc keyFunc) (count int, err error) {
	keys, listErr := r.doList(ctx, keyfunc)
	if listErr != nil {
		return 0, listErr
	}
	count = len(keys)
	for i := range keys {
		var relation internal.RelationToken
		// unreadable relation records are skipped here but still deleted below
		if loadErr := r.doLoad(ctx, keyFuncLiteral(keys[i]), &relation); loadErr != nil {
			continue
		}
		if rmErr := r.doRemoveAccessToken(ctx, nil, relation.TokenKey); rmErr != nil {
			err = rmErr
		}
	}
	if _, delErr := r.doDeleteWithKeys(ctx, keys); delErr != nil {
		err = delErr
	}
	return count, err
}
/*
********************
Refresh Token
********************
*/
// saveRefreshTokenToAuth stores the RefreshToken -> Authentication "ART" record.
// The record expires together with the refresh token.
func (r *RedisContextDetailsStore) saveRefreshTokenToAuth(c context.Context, t oauth2.RefreshToken, oauth oauth2.Authentication) error {
	key := keyFuncRefreshTokenToAuth(uniqueTokenKey(t))
	return r.doSave(c, key, oauth, t.ExpiryTime())
}
// saveRefreshTokenFromUserClient stores the RefreshToken <- User & Client "RUC" relation record.
// The record expires together with the refresh token.
func (r *RedisContextDetailsStore) saveRefreshTokenFromUserClient(c context.Context, t oauth2.RefreshToken, oauth oauth2.Authentication) error {
	clientId := oauth.OAuth2Request().ClientId()
	// the GetUsername error is deliberately ignored; whatever username it yields is recorded
	username, _ := security.GetUsername(oauth.UserAuthentication())
	rtk := uniqueTokenKey(t)
	relation := &internal.RelationTokenUserClient{
		RelationToken: internal.RelationToken{TokenKey: rtk},
		Username:      username,
		ClientId:      clientId,
	}
	return r.doSave(c, keyFuncRefreshTokenFromUserAndClient(rtk, username, clientId), relation, t.ExpiryTime())
}
// saveRefreshTokenToSession stores the RefreshToken -> SessionId "RS" relation record.
// It does nothing when the authentication carries no session ID.
func (r *RedisContextDetailsStore) saveRefreshTokenToSession(c context.Context, t oauth2.RefreshToken, oauth oauth2.Authentication) error {
	sid := r.findSessionId(c, oauth)
	if len(sid) == 0 {
		return nil
	}
	rtk := uniqueTokenKey(t)
	relation := &internal.RelationTokenSession{
		RelationToken: internal.RelationToken{TokenKey: rtk},
		SessionId:     sid,
	}
	return r.doSave(c, keyFuncRefreshTokenFromSession(rtk, sid), relation, t.ExpiryTime())
}
// loadAuthFromRefreshToken decodes the RefreshToken -> Authentication "ART" record for t.
// If the token is linked to a session and a timeoutApplier is configured, the session timeout is
// applied first. An error is returned when the session is no longer valid.
// NOTE(review): decoding goes through &oauth, so a stored JSON "null" would set oauth to nil and
// this would return (nil, nil) — confirm callers tolerate that.
func (r *RedisContextDetailsStore) loadAuthFromRefreshToken(c context.Context, t oauth2.RefreshToken) (oauth2.Authentication, error) {
	if sId, err := r.FindSessionId(c, t); err == nil && sId != "" && r.timeoutApplier != nil {
		if valid, _ := r.timeoutApplier.ApplyTimeout(c, sId); !valid {
			return nil, errors.New("token is invalid because it's expired by its associate session")
		}
	}
	// pre-populate every component, presumably so JSON decoding has concrete values to fill in
	oauth := oauth2.NewAuthentication(func(opt *oauth2.AuthOption) {
		opt.Request = oauth2.NewOAuth2Request()
		opt.UserAuth = oauth2.NewUserAuthentication()
		opt.Details = map[string]interface{}{}
	})
	key := keyFuncRefreshTokenToAuth(uniqueTokenKey(t))
	if err := r.doLoad(c, key, &oauth); err != nil {
		return nil, err
	}
	return oauth, nil
}
// doRemoveRefreshToken deletes every record belonging to one refresh token. The token is identified
// by token when non-nil, otherwise by the precomputed unique key rtk. Removed records:
//   - RefreshToken -> Authentication "ART"
//   - RefreshToken <- User & Client "RUC"
//   - RefreshToken -> SessionId "RS"
//   - All Access Tokens issued from it (each removal also drops AccessToken <-> RefreshToken "AR")
//
// Every deletion is attempted; the last error encountered (if any) is returned.
func (r *RedisContextDetailsStore) doRemoveRefreshToken(ctx context.Context, token oauth2.RefreshToken, rtk string) error {
	if token != nil {
		rtk = uniqueTokenKey(token)
	}
	removeAuth := func() (int, error) { return r.doDelete(ctx, keyFuncRefreshTokenToAuth(rtk)) }
	removeUserClient := func() (int, error) {
		return r.doDeleteWithWildcard(ctx, keyFuncRefreshTokenFromUserAndClient(rtk, "*", "*"))
	}
	removeSession := func() (int, error) {
		return r.doDeleteWithWildcard(ctx, keyFuncRefreshTokenFromSession(rtk, "*"))
	}
	removeAccessTokens := func() (int, error) {
		return r.doRemoveAllAccessTokens(ctx, keyFuncAccessFromRefresh("*", rtk))
	}
	return r.doMultiDelete(removeAuth, removeUserClient, removeSession, removeAccessTokens)
}
// doRemoveAllRefreshTokens finds every relation record matching keyfunc and removes the refresh token
// (and its dependents) each one points to. It then deletes the matched relation records themselves.
// count is the number of matched relation keys. err is the last error seen during removal; a
// failed listing is returned immediately.
func (r *RedisContextDetailsStore) doRemoveAllRefreshTokens(ctx context.Context, keyfunc keyFunc) (count int, err error) {
	keys, listErr := r.doList(ctx, keyfunc)
	if listErr != nil {
		return 0, listErr
	}
	count = len(keys)
	for i := range keys {
		var relation internal.RelationToken
		// unreadable relation records are skipped here but still deleted below
		if loadErr := r.doLoad(ctx, keyFuncLiteral(keys[i]), &relation); loadErr != nil {
			continue
		}
		if rmErr := r.doRemoveRefreshToken(ctx, nil, relation.TokenKey); rmErr != nil {
			err = rmErr
		}
	}
	if _, delErr := r.doDeleteWithKeys(ctx, keys); delErr != nil {
		err = delErr
	}
	return count, err
}
/*
********************
Other Helpers
********************
*/
// findSessionId returns the session ID recorded on the authentication, or "" if none is found.
//
// It looks in two places, in order:
//  1. the user authentication's details map, which should work for non-proxied authentications;
//  2. the KeyValueDetails of the authentication's details, which is how proxied authentications
//     carry the value.
//
// Only a non-empty string counts as a session ID. An empty or non-string value in the first
// location falls through to the second instead of ending the search.
func (r *RedisContextDetailsStore) findSessionId(_ context.Context, oauth oauth2.Authentication) string {
	if userAuth, ok := oauth.UserAuthentication().(oauth2.UserAuthentication); ok {
		if details := userAuth.DetailsMap(); details != nil {
			var v interface{} = details[security.DetailsKeySessionId]
			if sid, isStr := v.(string); isStr && sid != "" {
				return sid
			}
		}
	}
	if kvs, ok := oauth.Details().(security.KeyValueDetails); ok {
		if v, found := kvs.Value(security.DetailsKeySessionId); found {
			if sid, isStr := v.(string); isStr && sid != "" {
				return sid
			}
		}
	}
	return ""
}
/*
********************
Keys
********************
*/
// keyFunc builds a Redis key (or key pattern) for the given version tag.
// Callers in this file invoke it with r.vTag.
type keyFunc func(tag string) string
// keyFuncLiteral wraps an already-complete key, such as one returned by doList, as a keyFunc.
// The tag argument is ignored and key is returned unchanged.
func keyFuncLiteral(key string) keyFunc {
	return func(string) string { return key }
}
// keyFuncAccessTokenToDetails builds the AccessToken -> ContextDetails "AAT" key:
// <prefix>:<tag>:<accessTokenKey>.
func keyFuncAccessTokenToDetails(atk string) keyFunc {
	build := func(vTag string) string {
		return fmt.Sprintf("%s:%s:%s", prefixAccessTokenToDetails, vTag, atk) // SuppressWarnings go:S1192
	}
	return build
}
// keyFuncAccessTokenFromUserAndClient builds the AccessToken <- User & Client "AUC" key:
// <prefix>:<tag>:<username>:<clientId>:<accessTokenKey>. Any component may be "*" to form a pattern.
func keyFuncAccessTokenFromUserAndClient(atk, username, clientId string) keyFunc {
	build := func(vTag string) string {
		return fmt.Sprintf("%s:%s:%s:%s:%s", prefixAccessTokenFromUserAndClient, vTag, username, clientId, atk) // SuppressWarnings go:S1192
	}
	return build
}
// keyFuncAccessTokenFromSession builds the AccessToken -> SessionId "AS" key:
// <prefix>:<tag>:<sessionId>:<accessTokenKey>. Either component may be "*" to form a pattern.
func keyFuncAccessTokenFromSession(atk, sid string) keyFunc {
	build := func(vTag string) string {
		return fmt.Sprintf("%s:%s:%s:%s", prefixAccessTokenFromSessionId, vTag, sid, atk) // SuppressWarnings go:S1192
	}
	return build
}
// keyFuncAccessFromRefresh builds the AccessToken <-> RefreshToken "AR" key:
// <prefix>:<tag>:<accessTokenKey>:<refreshTokenKey>. Either component may be "*" to form a pattern.
func keyFuncAccessFromRefresh(atk, rtk string) keyFunc {
	build := func(vTag string) string {
		return fmt.Sprintf("%s:%s:%s:%s", prefixAccessFromRefreshToken, vTag, atk, rtk) // SuppressWarnings go:S1192
	}
	return build
}
// keyFuncRefreshTokenToAuth builds the RefreshToken -> Authentication "ART" key:
// <prefix>:<tag>:<refreshTokenKey>.
func keyFuncRefreshTokenToAuth(rtk string) keyFunc {
	build := func(vTag string) string {
		return fmt.Sprintf("%s:%s:%s", prefixRefreshTokenToAuthentication, vTag, rtk) // SuppressWarnings go:S1192
	}
	return build
}
// keyFuncRefreshTokenFromUserAndClient builds the RefreshToken <- User & Client "RUC" key:
// <prefix>:<tag>:<username>:<clientId>:<refreshTokenKey>. Any component may be "*" to form a pattern.
func keyFuncRefreshTokenFromUserAndClient(rtk, username, clientId string) keyFunc {
	build := func(vTag string) string {
		return fmt.Sprintf("%s:%s:%s:%s:%s", prefixRefreshTokenFromUserAndClient, vTag, username, clientId, rtk) // SuppressWarnings go:S1192
	}
	return build
}
// keyFuncRefreshTokenFromSession builds the RefreshToken -> SessionId "RS" key:
// <prefix>:<tag>:<sessionId>:<refreshTokenKey>. Either component may be "*" to form a pattern.
func keyFuncRefreshTokenFromSession(rtk, sid string) keyFunc {
	build := func(vTag string) string {
		return fmt.Sprintf("%s:%s:%s:%s", prefixRefreshTokenFromSessionId, vTag, sid, rtk) // SuppressWarnings go:S1192
	}
	return build
}
// uniqueTokenKey derives the identifier used in Redis keys for a token. The token's JWT ID claim
// (oauth2.ClaimJwtId) is used when the token exposes claims and that claim is a non-empty string.
// Otherwise the key is the lowercase hex SHA-224 digest of the token value.
func uniqueTokenKey(token oauth2.Token) string {
	if container, ok := token.(oauth2.ClaimsContainer); ok {
		if claims := container.Claims(); claims != nil {
			if jti, isStr := claims.Get(oauth2.ClaimJwtId).(string); isStr && jti != "" {
				return jti
			}
		}
	}
	digest := sha256.Sum224([]byte(token.Value()))
	return fmt.Sprintf("%x", digest)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package internal
import "github.com/cisco-open/go-lanai/pkg/security/oauth2"
// ExtendedClaims implements oauth2.Claims. It's used only for access token decoding.
// All accessors delegate to the embedded oauth2.FieldClaimsMapper with the ExtendedClaims itself
// as owner. The embedded oauth2.Claims is an oauth2.MapClaims when built by NewExtendedClaims;
// presumably it receives claims that have no matching BasicClaims field — TODO confirm against
// FieldClaimsMapper.
type ExtendedClaims struct {
	oauth2.FieldClaimsMapper
	oauth2.BasicClaims
	oauth2.Claims
}

// NewExtendedClaims creates an ExtendedClaims backed by an empty oauth2.MapClaims. It then copies
// every value of each given claims into it via Set, in argument order. Nil entries are skipped
// rather than dereferenced.
func NewExtendedClaims(claims ...oauth2.Claims) *ExtendedClaims {
	ptr := &ExtendedClaims{
		Claims: oauth2.MapClaims{},
	}
	for _, c := range claims {
		if c == nil {
			// e.g. SetClaims(nil) on the decoded token types ends up here
			continue
		}
		for k, v := range c.Values() {
			ptr.Set(k, v)
		}
	}
	return ptr
}

// MarshalJSON encodes the claims through FieldClaimsMapper.DoMarshalJSON.
func (c *ExtendedClaims) MarshalJSON() ([]byte, error) {
	return c.FieldClaimsMapper.DoMarshalJSON(c)
}

// UnmarshalJSON decodes the claims through FieldClaimsMapper.DoUnmarshalJSON.
func (c *ExtendedClaims) UnmarshalJSON(bytes []byte) error {
	return c.FieldClaimsMapper.DoUnmarshalJSON(c, bytes)
}

// Get returns the value of the named claim.
func (c *ExtendedClaims) Get(claim string) interface{} {
	return c.FieldClaimsMapper.Get(c, claim)
}

// Has reports whether the named claim is present.
func (c *ExtendedClaims) Has(claim string) bool {
	return c.FieldClaimsMapper.Has(c, claim)
}

// Set assigns the named claim.
func (c *ExtendedClaims) Set(claim string, value interface{}) {
	c.FieldClaimsMapper.Set(c, claim, value)
}

// Values returns all claims as a map.
func (c *ExtendedClaims) Values() map[string]interface{} {
	return c.FieldClaimsMapper.Values(c)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package internal
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
// ProviderDetails holds the provider attributes returned by the security.ProviderDetails accessors
// (ProviderId, ProviderName, ProviderDisplayName, ...) of the tenanted context details types in this package.
type ProviderDetails struct {
Id string
Name string
DisplayName string
Description string
NotificationType string
Email string
}
// TenantDetails holds the tenant attributes returned by the security.TenantDetails accessors
// (TenantId, TenantExternalId, TenantSuspended) of the tenanted context details types.
type TenantDetails struct {
Id string
ExternalId string
Suspended bool
}
// UserDetails holds the user attributes returned by the security.UserDetails accessors of
// ClientUserContextDetails.
type UserDetails struct {
Id string
Username string
AccountType security.AccountType
AssignedTenantIds utils.StringSet
LocaleCode string
CurrencyCode string
FirstName string
LastName string
Email string
}
// ClientDetails holds the OAuth2 client's ID, assigned tenant IDs and scopes.
type ClientDetails struct {
Id string
AssignedTenantIds utils.StringSet
Scopes utils.StringSet
}
// TenantAccessDetails holds the effective assigned tenant IDs, exposed through
// ClientUserContextDetails.EffectiveAssignedTenantIds.
type TenantAccessDetails struct {
EffectiveAssignedTenantIds utils.StringSet
}
// AuthenticationDetails holds authentication metadata: issue and expiry times, roles and
// permissions, the authentication time, and proxy information (OriginalUsername, Proxied).
// These back the security.AuthenticationDetails and security.ProxiedUserDetails accessors.
type AuthenticationDetails struct {
IssueTime time.Time
ExpiryTime time.Time
Roles utils.StringSet
Permissions utils.StringSet
AuthenticationTime time.Time
OriginalUsername string
Proxied bool
}
// ClientUserContextDetails bundles user, client, tenant-access and authentication information with
// free-form key/value details. It implements
//   - security.UserDetails
//   - security.AuthenticationDetails
//   - security.ProxiedUserDetails
//   - security.KeyValueDetails
//   - oauth2.ClientDetails
type ClientUserContextDetails struct {
	User           UserDetails
	Client         ClientDetails
	TenantAccess   TenantAccessDetails
	Authentication AuthenticationDetails
	KV             map[string]interface{}
}

// ClientId returns the OAuth2 client ID (oauth2.ClientDetails).
func (d *ClientUserContextDetails) ClientId() string { return d.Client.Id }

// Scopes returns the client's scopes (oauth2.ClientDetails).
func (d *ClientUserContextDetails) Scopes() utils.StringSet { return d.Client.Scopes }

// NewClientUserContextDetails returns a ClientUserContextDetails whose string sets are freshly
// created with utils.NewStringSet and whose KV map is empty and non-nil.
func NewClientUserContextDetails() *ClientUserContextDetails {
	d := &ClientUserContextDetails{KV: map[string]interface{}{}}
	d.Client.AssignedTenantIds = utils.NewStringSet()
	d.Client.Scopes = utils.NewStringSet()
	d.User.AssignedTenantIds = utils.NewStringSet()
	d.Authentication.Roles = utils.NewStringSet()
	d.Authentication.Permissions = utils.NewStringSet()
	d.TenantAccess.EffectiveAssignedTenantIds = utils.NewStringSet()
	return d
}

// UserId returns the user's ID (security.UserDetails).
func (d *ClientUserContextDetails) UserId() string { return d.User.Id }

// Username returns the user's username (security.UserDetails).
func (d *ClientUserContextDetails) Username() string { return d.User.Username }

// AccountType returns the user's account type (security.UserDetails).
func (d *ClientUserContextDetails) AccountType() security.AccountType { return d.User.AccountType }

// AssignedTenantIds returns the user's assigned tenant IDs (security.UserDetails).
//
// Deprecated: the interface is deprecated
func (d *ClientUserContextDetails) AssignedTenantIds() utils.StringSet { return d.User.AssignedTenantIds }

// LocaleCode returns the user's locale code (security.UserDetails).
func (d *ClientUserContextDetails) LocaleCode() string { return d.User.LocaleCode }

// CurrencyCode returns the user's currency code (security.UserDetails).
func (d *ClientUserContextDetails) CurrencyCode() string { return d.User.CurrencyCode }

// FirstName returns the user's first name (security.UserDetails).
func (d *ClientUserContextDetails) FirstName() string { return d.User.FirstName }

// LastName returns the user's last name (security.UserDetails).
func (d *ClientUserContextDetails) LastName() string { return d.User.LastName }

// Email returns the user's email (security.UserDetails).
func (d *ClientUserContextDetails) Email() string { return d.User.Email }

// ExpiryTime returns when the authentication expires (security.AuthenticationDetails).
func (d *ClientUserContextDetails) ExpiryTime() time.Time { return d.Authentication.ExpiryTime }

// IssueTime returns when the authentication was issued (security.AuthenticationDetails).
func (d *ClientUserContextDetails) IssueTime() time.Time { return d.Authentication.IssueTime }

// Roles returns the granted roles (security.AuthenticationDetails).
func (d *ClientUserContextDetails) Roles() utils.StringSet { return d.Authentication.Roles }

// Permissions returns the granted permissions (security.AuthenticationDetails).
func (d *ClientUserContextDetails) Permissions() utils.StringSet { return d.Authentication.Permissions }

// AuthenticationTime returns when authentication happened (security.AuthenticationDetails).
func (d *ClientUserContextDetails) AuthenticationTime() time.Time {
	return d.Authentication.AuthenticationTime
}

// OriginalUsername returns the original username of a proxied authentication (security.ProxiedUserDetails).
func (d *ClientUserContextDetails) OriginalUsername() string { return d.Authentication.OriginalUsername }

// Proxied reports whether the authentication is proxied (security.ProxiedUserDetails).
func (d *ClientUserContextDetails) Proxied() bool { return d.Authentication.Proxied }

// Value looks up a single key/value detail (security.KeyValueDetails).
func (d *ClientUserContextDetails) Value(key string) (interface{}, bool) {
	v, ok := d.KV[key]
	return v, ok
}

// Values returns a shallow copy of all key/value details (security.KeyValueDetails).
// The returned map is never nil.
func (d *ClientUserContextDetails) Values() map[string]interface{} {
	copied := make(map[string]interface{}, len(d.KV))
	for k, v := range d.KV {
		copied[k] = v
	}
	return copied
}

// EffectiveAssignedTenantIds returns the effective assigned tenant IDs.
func (d *ClientUserContextDetails) EffectiveAssignedTenantIds() utils.StringSet {
	return d.TenantAccess.EffectiveAssignedTenantIds
}
// ClientUserTenantedContextDetails extends ClientUserContextDetails with tenant and provider
// details. It implements
//   - security.UserDetails
//   - security.TenantDetails
//   - security.ProviderDetails
//   - security.AuthenticationDetails
//   - security.ProxiedUserDetails
//   - security.KeyValueDetails
//   - oauth2.ClientDetails
type ClientUserTenantedContextDetails struct {
	ClientUserContextDetails
	Tenant   TenantDetails
	Provider ProviderDetails
}

// NewClientUserTenantedContextDetails returns a ClientUserTenantedContextDetails whose embedded
// ClientUserContextDetails is initialized exactly as by NewClientUserContextDetails.
// Tenant and Provider are zero-valued.
func NewClientUserTenantedContextDetails() *ClientUserTenantedContextDetails {
	return &ClientUserTenantedContextDetails{
		// reuse the base constructor so both types always initialize the same collections
		ClientUserContextDetails: *NewClientUserContextDetails(),
		Tenant:                   TenantDetails{},
		Provider:                 ProviderDetails{},
	}
}

// TenantId returns the tenant's ID (security.TenantDetails).
func (d *ClientUserTenantedContextDetails) TenantId() string {
	return d.Tenant.Id
}

// TenantExternalId returns the tenant's external ID (security.TenantDetails).
func (d *ClientUserTenantedContextDetails) TenantExternalId() string {
	return d.Tenant.ExternalId
}

// TenantSuspended reports whether the tenant is suspended (security.TenantDetails).
func (d *ClientUserTenantedContextDetails) TenantSuspended() bool {
	return d.Tenant.Suspended
}

// ProviderId returns the provider's ID (security.ProviderDetails).
func (d *ClientUserTenantedContextDetails) ProviderId() string {
	return d.Provider.Id
}

// ProviderName returns the provider's name (security.ProviderDetails).
func (d *ClientUserTenantedContextDetails) ProviderName() string {
	return d.Provider.Name
}

// ProviderDisplayName returns the provider's display name (security.ProviderDetails).
func (d *ClientUserTenantedContextDetails) ProviderDisplayName() string {
	return d.Provider.DisplayName
}

// ProviderDescription returns the provider's description (security.ProviderDetails).
func (d *ClientUserTenantedContextDetails) ProviderDescription() string {
	return d.Provider.Description
}

// ProviderEmail returns the provider's email (security.ProviderDetails).
func (d *ClientUserTenantedContextDetails) ProviderEmail() string {
	return d.Provider.Email
}

// ProviderNotificationType returns the provider's notification type (security.ProviderDetails).
func (d *ClientUserTenantedContextDetails) ProviderNotificationType() string {
	return d.Provider.NotificationType
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package internal
import (
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
// ClientContextDetails represents a client credential: client and authentication information
// plus free-form key/value details, with no user. It implements
//   - security.AuthenticationDetails
//   - security.KeyValueDetails
//   - oauth2.ClientDetails
type ClientContextDetails struct {
	Authentication AuthenticationDetails
	Client         ClientDetails
	KV             map[string]interface{}
	TenantAccess   TenantAccessDetails
}

// ClientId returns the OAuth2 client ID (oauth2.ClientDetails).
func (d *ClientContextDetails) ClientId() string { return d.Client.Id }

// AssignedTenantIds returns the client's assigned tenant IDs.
func (d *ClientContextDetails) AssignedTenantIds() utils.StringSet { return d.Client.AssignedTenantIds }

// Scopes returns the client's scopes (oauth2.ClientDetails).
func (d *ClientContextDetails) Scopes() utils.StringSet { return d.Client.Scopes }

// ExpiryTime returns when the authentication expires (security.AuthenticationDetails).
func (d *ClientContextDetails) ExpiryTime() time.Time { return d.Authentication.ExpiryTime }

// IssueTime returns when the authentication was issued (security.AuthenticationDetails).
func (d *ClientContextDetails) IssueTime() time.Time { return d.Authentication.IssueTime }

// Roles returns the granted roles (security.AuthenticationDetails).
func (d *ClientContextDetails) Roles() utils.StringSet { return d.Authentication.Roles }

// Permissions returns the granted permissions (security.AuthenticationDetails).
func (d *ClientContextDetails) Permissions() utils.StringSet { return d.Authentication.Permissions }

// AuthenticationTime returns when authentication happened (security.AuthenticationDetails).
func (d *ClientContextDetails) AuthenticationTime() time.Time {
	return d.Authentication.AuthenticationTime
}

// Value looks up a single key/value detail (security.KeyValueDetails).
func (d *ClientContextDetails) Value(key string) (interface{}, bool) {
	v, ok := d.KV[key]
	return v, ok
}

// Values returns a shallow copy of all key/value details (security.KeyValueDetails).
// The returned map is never nil.
func (d *ClientContextDetails) Values() map[string]interface{} {
	copied := make(map[string]interface{}, len(d.KV))
	for k, v := range d.KV {
		copied[k] = v
	}
	return copied
}
// ClientTenantedContextDetails represents a client credential with a selected tenant: it extends
// ClientContextDetails with tenant and provider details. It implements
//   - security.AuthenticationDetails
//   - security.KeyValueDetails
//   - security.TenantDetails
//   - security.ProviderDetails
//   - oauth2.ClientDetails
type ClientTenantedContextDetails struct {
	ClientContextDetails
	Tenant   TenantDetails
	Provider ProviderDetails
}

// TenantId returns the tenant's ID (security.TenantDetails).
func (d *ClientTenantedContextDetails) TenantId() string { return d.Tenant.Id }

// TenantExternalId returns the tenant's external ID (security.TenantDetails).
func (d *ClientTenantedContextDetails) TenantExternalId() string { return d.Tenant.ExternalId }

// TenantSuspended reports whether the tenant is suspended (security.TenantDetails).
func (d *ClientTenantedContextDetails) TenantSuspended() bool { return d.Tenant.Suspended }

// ProviderId returns the provider's ID (security.ProviderDetails).
func (d *ClientTenantedContextDetails) ProviderId() string { return d.Provider.Id }

// ProviderName returns the provider's name (security.ProviderDetails).
func (d *ClientTenantedContextDetails) ProviderName() string { return d.Provider.Name }

// ProviderDisplayName returns the provider's display name (security.ProviderDetails).
func (d *ClientTenantedContextDetails) ProviderDisplayName() string { return d.Provider.DisplayName }

// ProviderDescription returns the provider's description (security.ProviderDetails).
func (d *ClientTenantedContextDetails) ProviderDescription() string { return d.Provider.Description }

// ProviderEmail returns the provider's email (security.ProviderDetails).
func (d *ClientTenantedContextDetails) ProviderEmail() string { return d.Provider.Email }

// ProviderNotificationType returns the provider's notification type (security.ProviderDetails).
func (d *ClientTenantedContextDetails) ProviderNotificationType() string {
	return d.Provider.NotificationType
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package internal
import (
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
// DecodedAccessToken implements oauth2.AccessToken and oauth2.ClaimsContainer for an access token
// reconstructed from its encoded value. It never carries a refresh token and reports no extra details.
type DecodedAccessToken struct {
	DecodedClaims *ExtendedClaims
	TokenValue    string
	ExpireAt      time.Time
	IssuedAt      time.Time
	ScopesSet     utils.StringSet
}

// NewDecodedAccessToken returns an empty DecodedAccessToken (no claims, no value, no expiry).
func NewDecodedAccessToken() *DecodedAccessToken {
	return &DecodedAccessToken{}
}

// Value returns the encoded token value.
func (t *DecodedAccessToken) Value() string {
	return t.TokenValue
}

// ExpiryTime returns the expiry time; the zero time means no expiry was set.
func (t *DecodedAccessToken) ExpiryTime() time.Time {
	return t.ExpireAt
}

// Expired reports whether an expiry is set and lies in the past.
func (t *DecodedAccessToken) Expired() bool {
	return !t.ExpireAt.IsZero() && t.ExpireAt.Before(time.Now())
}

// Details always returns a new, empty map.
func (t *DecodedAccessToken) Details() map[string]interface{} {
	return map[string]interface{}{}
}

// Type always reports oauth2.TokenTypeBearer.
func (t *DecodedAccessToken) Type() oauth2.TokenType {
	return oauth2.TokenTypeBearer
}

// IssueTime returns the issue time.
func (t *DecodedAccessToken) IssueTime() time.Time {
	return t.IssuedAt
}

// Scopes returns the token's scopes.
func (t *DecodedAccessToken) Scopes() utils.StringSet {
	return t.ScopesSet
}

// RefreshToken always returns nil: a decoded access token carries no refresh token.
func (t *DecodedAccessToken) RefreshToken() oauth2.RefreshToken {
	return nil
}

// Claims returns the decoded claims, or a nil oauth2.Claims when none are set (oauth2.ClaimsContainer).
// Returning the nil *ExtendedClaims directly would yield a non-nil interface holding a nil pointer.
// Callers' `Claims() != nil` checks would then pass and the following method calls would
// dereference nil.
func (t *DecodedAccessToken) Claims() oauth2.Claims {
	if t.DecodedClaims == nil {
		return nil
	}
	return t.DecodedClaims
}

// SetClaims replaces the decoded claims (oauth2.ClaimsContainer). An *ExtendedClaims is stored
// as-is, any other claims are copied into a new ExtendedClaims, and nil clears the claims.
func (t *DecodedAccessToken) SetClaims(claims oauth2.Claims) {
	if claims == nil {
		t.DecodedClaims = nil
		return
	}
	if c, ok := claims.(*ExtendedClaims); ok {
		t.DecodedClaims = c
		return
	}
	t.DecodedClaims = NewExtendedClaims(claims)
}
// DecodedRefreshToken implements oauth2.RefreshToken and oauth2.ClaimsContainer for a refresh token
// reconstructed from its encoded value. It reports no extra details.
type DecodedRefreshToken struct {
	DecodedClaims *ExtendedClaims
	TokenValue    string
	ExpireAt      time.Time
	IssuedAt      time.Time
	ScopesSet     utils.StringSet
}

// Value returns the encoded token value.
func (t *DecodedRefreshToken) Value() string {
	return t.TokenValue
}

// ExpiryTime returns the expiry time; the zero time means no expiry was set.
func (t *DecodedRefreshToken) ExpiryTime() time.Time {
	return t.ExpireAt
}

// Expired reports whether an expiry is set and lies in the past.
func (t *DecodedRefreshToken) Expired() bool {
	return !t.ExpireAt.IsZero() && t.ExpireAt.Before(time.Now())
}

// Details always returns a new, empty map.
func (t *DecodedRefreshToken) Details() map[string]interface{} {
	return map[string]interface{}{}
}

// WillExpire reports whether an expiry time is set.
func (t *DecodedRefreshToken) WillExpire() bool {
	return !t.ExpireAt.IsZero()
}

// Claims returns the decoded claims, or a nil oauth2.Claims when none are set (oauth2.ClaimsContainer).
// Returning the nil *ExtendedClaims directly would yield a non-nil interface holding a nil pointer.
func (t *DecodedRefreshToken) Claims() oauth2.Claims {
	if t.DecodedClaims == nil {
		return nil
	}
	return t.DecodedClaims
}

// SetClaims replaces the decoded claims (oauth2.ClaimsContainer). An *ExtendedClaims is stored
// as-is, any other claims are copied into a new ExtendedClaims, and nil clears the claims.
func (t *DecodedRefreshToken) SetClaims(claims oauth2.Claims) {
	if claims == nil {
		t.DecodedClaims = nil
		return
	}
	if c, ok := claims.(*ExtendedClaims); ok {
		t.DecodedClaims = c
		return
	}
	t.DecodedClaims = NewExtendedClaims(claims)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package common
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/common/internal"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// jwtTokenStoreReader implements TokenStoreReader for JWT-encoded tokens. Token values are decoded
// with jwtDecoder, and detailsStore is consulted for the context details stored for access tokens.
// An access token without stored details is reported as revoked.
type jwtTokenStoreReader struct {
detailsStore security.ContextDetailsStore
jwtDecoder jwt.JwtDecoder
}
// JTSROptions is a functional option that configures a JTSROption for NewJwtTokenStoreReader.
type JTSROptions func(opt *JTSROption)
// JTSROption holds the dependencies of a jwtTokenStoreReader.
type JTSROption struct {
// DetailsStore is used to look up the context details stored for access tokens.
DetailsStore security.ContextDetailsStore
// Decoder decodes JWT token values into claims.
Decoder jwt.JwtDecoder
}
// NewJwtTokenStoreReader builds a jwtTokenStoreReader from the given option functions.
// The options are applied in order to a zero-valued JTSROption.
func NewJwtTokenStoreReader(opts ...JTSROptions) *jwtTokenStoreReader {
	var opt JTSROption
	for i := range opts {
		opts[i](&opt)
	}
	return &jwtTokenStoreReader{
		detailsStore: opt.DetailsStore,
		jwtDecoder:   opt.Decoder,
	}
}
// ReadAuthentication reconstructs the authentication carried by a token value. Only
// oauth2.TokenHintAccessToken is supported; any other hint yields an unsupported-token-type error.
func (r *jwtTokenStoreReader) ReadAuthentication(ctx context.Context, tokenValue string, hint oauth2.TokenHint) (oauth2.Authentication, error) {
	if hint == oauth2.TokenHintAccessToken {
		return r.readAuthenticationFromAccessToken(ctx, tokenValue)
	}
	return nil, oauth2.NewUnsupportedTokenTypeError(fmt.Sprintf("token type [%s] is not supported", hint.String()))
}
// ReadAccessToken decodes an access token value. It rejects the token as an invalid access token
// when it cannot be parsed, when it is expired, or when the details store has no context details
// for it (reported as revoked).
func (r *jwtTokenStoreReader) ReadAccessToken(c context.Context, value string) (oauth2.AccessToken, error) {
	token, err := r.parseAccessToken(c, value)
	if err != nil {
		return nil, oauth2.NewInvalidAccessTokenError("token is invalid", err)
	}
	if token.Expired() {
		return nil, oauth2.NewInvalidAccessTokenError("token is expired")
	}
	if !r.detailsStore.ContextDetailsExists(c, token) {
		return nil, oauth2.NewInvalidAccessTokenError("token is revoked")
	}
	return token, nil
}
func (r *jwtTokenStoreReader) ReadRefreshToken(c context.Context, value string) (oauth2.RefreshToken, error) {
token, e := r.parseRefreshToken(c, value)
switch {
case e != nil:
return nil, oauth2.NewInvalidGrantError("refresh token is invalid", e)
case token.WillExpire() && token.Expired():
return nil, oauth2.NewInvalidGrantError("refresh token is expired")
}
return token, nil
}
// parseAccessToken decodes value into internal.ExtendedClaims and wraps the
// result as an internal.DecodedAccessToken. Decoder errors are returned as-is.
func (r *jwtTokenStoreReader) parseAccessToken(c context.Context, value string) (*internal.DecodedAccessToken, error) {
	claims := &internal.ExtendedClaims{}
	if err := r.jwtDecoder.DecodeWithClaims(c, value, claims); err != nil {
		return nil, err
	}
	decoded := &internal.DecodedAccessToken{}
	decoded.TokenValue = value
	decoded.DecodedClaims = claims
	decoded.ExpireAt = claims.ExpiresAt
	decoded.IssuedAt = claims.IssuedAt
	decoded.ScopesSet = claims.Scopes.Copy()
	return decoded, nil
}

// parseRefreshToken decodes value into internal.ExtendedClaims and wraps the
// result as an internal.DecodedRefreshToken. Decoder errors are returned as-is.
func (r *jwtTokenStoreReader) parseRefreshToken(c context.Context, value string) (*internal.DecodedRefreshToken, error) {
	claims := &internal.ExtendedClaims{}
	if err := r.jwtDecoder.DecodeWithClaims(c, value, claims); err != nil {
		return nil, err
	}
	decoded := &internal.DecodedRefreshToken{}
	decoded.TokenValue = value
	decoded.DecodedClaims = claims
	decoded.ExpireAt = claims.ExpiresAt
	decoded.IssuedAt = claims.IssuedAt
	decoded.ScopesSet = claims.Scopes.Copy()
	return decoded, nil
}
// readAuthenticationFromAccessToken rebuilds an oauth2.Authentication from a JWT
// access token value. It combines the token's claims with the context details
// stored for the token. A user authentication is only attached when the
// token has a non-empty subject.
func (r *jwtTokenStoreReader) readAuthenticationFromAccessToken(c context.Context, tokenValue string) (oauth2.Authentication, error) {
	token, err := r.parseAccessToken(c, tokenValue)
	if err != nil {
		return nil, err
	}

	claims := token.DecodedClaims
	if claims == nil {
		return nil, oauth2.NewInvalidAccessTokenError("token contains no claims")
	}

	details, err := r.detailsStore.ReadContextDetails(c, token)
	if err != nil {
		return nil, oauth2.NewInvalidAccessTokenError("token unknown", err)
	}

	request := r.createOAuth2Request(claims, details)

	var userAuth security.Authentication
	if claims.Subject != "" {
		userAuth = r.createUserAuthentication(claims, details)
	}

	auth := oauth2.NewAuthentication(func(opt *oauth2.AuthOption) {
		opt.Request = request
		opt.UserAuth = userAuth
		opt.Token = token
		opt.Details = details
	})
	return auth, nil
}
/*****************
Helpers
*****************/
// createOAuth2Request reconstructs an approved oauth2.OAuth2Request from token
// claims and stored context details.
//   - ClientId comes from the client_id claim, falling back to the first value
//     reported by the audience set when client_id is empty.
//   - Parameters are the string-valued entries of the stored request params.
//   - Extensions are the claim values, overlaid with the stored request extensions.
//
// GrantType, RedirectUri and ResponseTypes are not populated here.
func (r *jwtTokenStoreReader) createOAuth2Request(claims *internal.ExtendedClaims, details security.ContextDetails) oauth2.OAuth2Request {
	clientId := claims.ClientId
	if clientId == "" && len(claims.Audience) != 0 {
		clientId = utils.StringSet(claims.Audience).Values()[0]
	}

	params := map[string]string{}
	rawParams, _ := details.Value(oauth2.DetailsKeyRequestParams)
	if paramMap, ok := rawParams.(map[string]interface{}); ok {
		for key, val := range paramMap {
			// non-string parameter values are dropped
			if str, isStr := val.(string); isStr {
				params[key] = str
			}
		}
	}

	ext := claims.Values()
	rawExt, _ := details.Value(oauth2.DetailsKeyRequestExt)
	if extMap, ok := rawExt.(map[string]interface{}); ok {
		for key, val := range extMap {
			ext[key] = val
		}
	}

	return oauth2.NewOAuth2Request(func(opt *oauth2.RequestDetails) {
		opt.Parameters = params
		opt.ClientId = clientId
		opt.Scopes = claims.Scopes
		opt.Approved = true
		opt.Extensions = ext
	})
}
// createUserAuthentication builds an authenticated user authentication whose
// principal is the token subject. Every permission key from the context details
// is mapped to true, and Details holds the values of the embedded claims (or an
// empty map when there are none).
func (r *jwtTokenStoreReader) createUserAuthentication(claims *internal.ExtendedClaims, details security.ContextDetails) security.Authentication {
	granted := map[string]interface{}{}
	for name := range details.Permissions() {
		granted[name] = true
	}

	var userDetails map[string]interface{}
	if claims.Claims != nil {
		userDetails = claims.Claims.Values()
	} else {
		userDetails = map[string]interface{}{}
	}

	return oauth2.NewUserAuthentication(func(opt *oauth2.UserAuthOption) {
		opt.Principal = claims.Subject
		opt.Permissions = granted
		opt.State = security.StateAuthenticated
		opt.Details = userDetails
	})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
)
/******************************
security.Authentication
******************************/
// Authentication extends security.Authentication with OAuth2-specific data:
// the user authentication (if any), the OAuth2 request, and the access token.
type Authentication interface {
security.Authentication
// UserAuthentication returns the end-user authentication; it may be nil
// (e.g. when no user subject is involved).
UserAuthentication() security.Authentication
// OAuth2Request returns the OAuth2 request this authentication represents.
OAuth2Request() OAuth2Request
// AccessToken returns the access token backing this authentication, if any.
AccessToken() AccessToken
}
// AuthenticationOptions is a functional option applied to AuthOption by NewAuthentication.
type AuthenticationOptions func(opt *AuthOption)
// AuthOption holds the inputs of NewAuthentication.
type AuthOption struct {
Request OAuth2Request
UserAuth security.Authentication
Token AccessToken
// Details is stored as-is and returned by (*authentication).Details.
Details interface{}
}
// authentication is the default Authentication implementation.
// Only the exported, json-tagged fields are visible to encoding/json; token and
// details are unexported and therefore not serialized by it.
type authentication struct {
Request OAuth2Request `json:"request"`
UserAuth security.Authentication `json:"userAuth"`
AuthState security.AuthenticationState `json:"state"`
token AccessToken
details interface{}
}
// NewAuthentication builds an Authentication from the given options. The
// resulting state is derived from the request approval and the user
// authentication via calculateState.
func NewAuthentication(opts ...AuthenticationOptions) Authentication {
	var cfg AuthOption
	for _, apply := range opts {
		apply(&cfg)
	}
	auth := &authentication{
		Request:  cfg.Request,
		UserAuth: cfg.UserAuth,
		token:    cfg.Token,
		details:  cfg.Details,
	}
	auth.AuthState = calculateState(cfg.Request, cfg.UserAuth)
	return auth
}
// Principal returns the user principal when a user authentication is present.
// Otherwise it falls back to the client ID of the OAuth2 request. It returns
// nil when neither a user authentication nor a request is available.
func (a *authentication) Principal() interface{} {
	if a.UserAuth != nil {
		return a.UserAuth.Principal()
	}
	if a.Request == nil {
		// NewAuthentication accepts a nil Request; calling ClientId on it would panic
		return nil
	}
	return a.Request.ClientId()
}
// Permissions returns the user authentication's permissions, or an empty
// permission map when there is no user authentication.
func (a *authentication) Permissions() security.Permissions {
	if a.UserAuth != nil {
		return a.UserAuth.Permissions()
	}
	return map[string]interface{}{}
}

// State returns the state computed when the authentication was created.
func (a *authentication) State() security.AuthenticationState {
	state := a.AuthState
	return state
}

// Details returns the details value supplied via AuthOption.Details.
func (a *authentication) Details() interface{} {
	details := a.details
	return details
}

// UserAuthentication returns the wrapped user authentication, possibly nil.
func (a *authentication) UserAuthentication() security.Authentication {
	userAuth := a.UserAuth
	return userAuth
}

// OAuth2Request returns the wrapped OAuth2 request.
func (a *authentication) OAuth2Request() OAuth2Request {
	req := a.Request
	return req
}

// AccessToken returns the access token supplied via AuthOption.Token.
func (a *authentication) AccessToken() AccessToken {
	token := a.token
	return token
}
// calculateState derives an authentication state:
//   - approved request with a user auth: the user auth's own state
//   - approved request without a user auth: security.StateAuthenticated
//   - nil or unapproved request with a user auth: security.StatePrincipalKnown
//   - otherwise: security.StateAnonymous
func calculateState(req OAuth2Request, userAuth security.Authentication) security.AuthenticationState {
	approved := req != nil && req.Approved()
	switch {
	case approved && userAuth != nil:
		return userAuth.State()
	case approved:
		return security.StateAuthenticated
	case userAuth != nil:
		return security.StatePrincipalKnown
	default:
		return security.StateAnonymous
	}
}
/******************************
UserAuthentication
******************************/
// UserAuthentication is a security.Authentication describing an end user,
// additionally exposing the subject and the details as a map.
type UserAuthentication interface {
security.Authentication
Subject() string
DetailsMap() map[string]interface{}
}
// UserAuthOptions is a functional option applied to UserAuthOption by NewUserAuthentication.
type UserAuthOptions func(opt *UserAuthOption)
// UserAuthOption holds the inputs of NewUserAuthentication.
type UserAuthOption struct {
Principal string
Permissions map[string]interface{}
State security.AuthenticationState
Details map[string]interface{}
}
// userAuthentication implements security.Authentication and UserAuthentication.
// It represents basic information that can typically be extracted from JWT claims.
// userAuthentication is also used for serializing/deserializing (see its json tags).
type userAuthentication struct {
SubjectVal string `json:"principal"`
PermissionVal map[string]interface{} `json:"permissions"`
StateVal security.AuthenticationState `json:"state"`
DetailsVal map[string]interface{} `json:"details"`
}
// NewUserAuthentication builds a userAuthentication from the given options.
// Permissions and Details start as empty, non-nil maps unless an option
// replaces them.
func NewUserAuthentication(opts ...UserAuthOptions) *userAuthentication {
	cfg := UserAuthOption{
		Details:     map[string]interface{}{},
		Permissions: map[string]interface{}{},
	}
	for _, apply := range opts {
		apply(&cfg)
	}
	auth := userAuthentication{
		SubjectVal:    cfg.Principal,
		PermissionVal: cfg.Permissions,
		StateVal:      cfg.State,
		DetailsVal:    cfg.Details,
	}
	return &auth
}
// Principal returns the subject string as the principal.
func (a *userAuthentication) Principal() interface{} {
return a.SubjectVal
}
// Permissions returns the stored permission map.
func (a *userAuthentication) Permissions() security.Permissions {
return a.PermissionVal
}
// State returns the stored authentication state.
func (a *userAuthentication) State() security.AuthenticationState {
return a.StateVal
}
// Details returns the details map as an untyped value.
func (a *userAuthentication) Details() interface{} {
return a.DetailsVal
}
// Subject returns the subject (same value as Principal, typed as string).
func (a *userAuthentication) Subject() string {
return a.SubjectVal
}
// DetailsMap returns the details map (same value as Details, typed as a map).
func (a *userAuthentication) DetailsMap() map[string]interface{} {
return a.DetailsVal
}
/*********************
Timeout Support
*********************/
// TimeoutApplier applies a timeout policy to the session identified by sessionId.
// NOTE(review): presumably valid reports whether the session is still valid after
// the timeout is applied — confirm against implementations.
type TimeoutApplier interface {
ApplyTimeout(ctx context.Context, sessionId string) (valid bool, err error)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"encoding/json"
"github.com/cisco-open/go-lanai/pkg/utils"
"reflect"
"time"
)
const (
// ClaimTag is the struct tag key mapping a struct field to a claim name, e.g. `claim:"sub"`.
ClaimTag = "claim"
)
// Claims is a generic accessor over a collection of named claims.
type Claims interface {
// Get returns the value of the named claim.
Get(claim string) interface{}
// Has reports whether the named claim is present.
Has(claim string) bool
// Set stores value under the named claim.
Set(claim string, value interface{})
// Values returns the claims as a map.
Values() map[string]interface{}
}
// StringSetClaim is a defined type over utils.StringSet with JSON serialization specialized for some claims.
// It serializes as a JSON string when the set has exactly one element; otherwise it falls back to
// utils.StringSet's own JSON encoding (presumably a JSON array).
type StringSetClaim utils.StringSet
// MarshalJSON json.Marshaler
func (s StringSetClaim) MarshalJSON() ([]byte, error) {
switch len(s) {
case 1:
var v string
for v = range s {
// SuppressWarnings go:S108 empty block is intended to get any entry in the set
}
return json.Marshal(v)
default:
return utils.StringSet(s).MarshalJSON()
}
}
// UnmarshalJSON json.Unmarshaler
func (s StringSetClaim) UnmarshalJSON(data []byte) error {
values := make([]string, 0)
if e := json.Unmarshal(data, &values); e == nil {
utils.StringSet(s).Add(values...)
return nil
}
// fallback to string
value := ""
if e := json.Unmarshal(data, &value); e != nil {
return e
}
if value != "" {
utils.StringSet(s).Add(value)
}
return nil
}
/*********************
Implements
*********************/
// MapClaims implements Claims as a plain map keyed by claim name.
// Note: its toMap/fromMap helpers take no owner argument, so MapClaims does not
// satisfy the claimsMapper interface; FieldClaimsMapper checks for MapClaims explicitly.
type MapClaims map[string]interface{}
func (c MapClaims) MarshalJSON() ([]byte, error) {
m, e := c.toMap(true)
if e != nil {
return nil, e
}
return json.Marshal(m)
}
func (c MapClaims) UnmarshalJSON(bytes []byte) error {
m := map[string]interface{}{}
if e := json.Unmarshal(bytes, &m); e != nil {
return e
}
return c.fromMap(m)
}
// Get returns the raw value stored for claim, or nil when it is absent.
func (c MapClaims) Get(claim string) interface{} {
	if value, found := c[claim]; found {
		return value
	}
	return nil
}

// Has reports whether claim is present as a key (even with a nil value).
func (c MapClaims) Has(claim string) bool {
	_, found := c[claim]
	return found
}

// Set stores value under claim.
func (c MapClaims) Set(claim string, value interface{}) {
	c[claim] = value
}

// Values returns an unconverted copy of the claims, or an empty map on error.
func (c MapClaims) Values() map[string]interface{} {
	if values, err := c.toMap(false); err == nil {
		return values
	}
	return map[string]interface{}{}
}
// toMap copies the claims into a new map. When convert is true, each value goes
// through claimMarshalConvert (used for JSON serialization).
// nil values are copied unchanged. reflect.ValueOf(nil) is an invalid
// reflect.Value, and claimMarshalConvert would panic on its Type().
func (c MapClaims) toMap(convert bool) (map[string]interface{}, error) {
	ret := make(map[string]interface{}, len(c))
	for k, v := range c {
		if !convert || v == nil {
			ret[k] = v
			continue
		}
		value, e := claimMarshalConvert(reflect.ValueOf(v))
		if e != nil {
			return nil, e
		}
		ret[k] = value.Interface()
	}
	return ret, nil
}
// fromMap converts each decoded value via claimUnmarshalConvert and stores it in c.
// nil values (JSON null) are stored as nil directly. reflect.ValueOf(nil) is
// invalid, and the converter would panic calling Type() on it.
func (c MapClaims) fromMap(src map[string]interface{}) error {
	for k, v := range src {
		if v == nil {
			c[k] = nil
			continue
		}
		value, e := claimUnmarshalConvert(reflect.ValueOf(v), anyType)
		if e != nil {
			return e
		}
		c[k] = value.Interface()
	}
	return nil
}
// BasicClaims implements Claims for the registered JWT claims plus scope and client_id.
// Each field is mapped to its claim name through the `claim` struct tag, and the
// embedded FieldClaimsMapper performs the reflection-based access.
type BasicClaims struct {
FieldClaimsMapper
Audience StringSetClaim `claim:"aud"`
ExpiresAt time.Time `claim:"exp"`
Id string `claim:"jti"`
IssuedAt time.Time `claim:"iat"`
Issuer string `claim:"iss"`
NotBefore time.Time `claim:"nbf"`
Subject string `claim:"sub"`
Scopes utils.StringSet `claim:"scope"`
ClientId string `claim:"client_id"`
}
// MarshalJSON implements json.Marshaler by delegating to the embedded FieldClaimsMapper.
func (c *BasicClaims) MarshalJSON() ([]byte, error) {
	return c.DoMarshalJSON(c)
}

// UnmarshalJSON implements json.Unmarshaler by delegating to the embedded FieldClaimsMapper.
func (c *BasicClaims) UnmarshalJSON(bytes []byte) error {
	return c.DoUnmarshalJSON(c, bytes)
}

// Get returns the value of claim.
// Get, Has, Set and Values must call the embedded mapper explicitly: a plain
// c.Get would resolve to this method itself and recurse forever.
func (c *BasicClaims) Get(claim string) interface{} {
	mapper := &c.FieldClaimsMapper
	return mapper.Get(c, claim)
}

// Has reports whether claim has a non-zero value.
func (c *BasicClaims) Has(claim string) bool {
	mapper := &c.FieldClaimsMapper
	return mapper.Has(c, claim)
}

// Set assigns value to claim.
func (c *BasicClaims) Set(claim string, value interface{}) {
	mapper := &c.FieldClaimsMapper
	mapper.Set(c, claim, value)
}

// Values returns all claims as an unconverted map.
func (c *BasicClaims) Values() map[string]interface{} {
	mapper := &c.FieldClaimsMapper
	return mapper.Values(c)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"reflect"
"time"
)
// Cached reflect.Type values used by the claim conversion helpers.
var (
stringType = reflect.TypeOf("")
timeType = reflect.TypeOf(time.Time{})
int64Type = reflect.TypeOf(int64(0))
float64Type = reflect.TypeOf(float64(0))
float32Type = reflect.TypeOf(float32(0))
sSliceType = reflect.TypeOf([]string{})
iSliceType = reflect.TypeOf([]interface{}{})
sSetType = reflect.TypeOf(utils.NewStringSet())
sSetClaimType = reflect.TypeOf(StringSetClaim(utils.NewStringSet()))
iSetType = reflect.TypeOf(utils.NewSet())
mapType = reflect.TypeOf(map[string]interface{}{})
// NOTE: reflect.TypeOf(interface{}(0)) returns the dynamic type of its argument,
// i.e. int — not the interface{} type. MapClaims.fromMap passes anyType as the
// target type, so claimUnmarshalConvert effectively turns decoded float numbers
// into int there. Changing this to the real interface{} type would route every
// value through the time.Time branch first; review callers before changing it.
anyType = reflect.TypeOf(interface{}(0))
)
// claimMarshalConvert converts a claim value into its serialized form:
//   - time.Time -> Unix timestamp in seconds (int64)
//   - float64/float32 -> int64 (the fractional part is truncated)
//   - anything else is returned unchanged
//
// Interface-typed values (e.g. struct fields declared as interface{}) are
// unwrapped first, so the conversion is picked from the dynamic type instead of
// the interface type. Without this, every interface field would match the
// time.Time case and fail. Invalid values and nil interfaces are returned as-is.
func claimMarshalConvert(v reflect.Value) (reflect.Value, error) {
	if !v.IsValid() {
		return v, nil
	}
	if v.Kind() == reflect.Interface {
		if v.IsNil() {
			return v, nil
		}
		v = v.Elem()
	}
	t := v.Type()
	switch {
	case timeType.AssignableTo(t):
		return timeToTimestamp(v)
	case float64Type.AssignableTo(t), float32Type.AssignableTo(t):
		return v.Convert(int64Type), nil
	default:
		return v, nil
	}
}
// claimUnmarshalConvert converts a decoded claim value v (typically produced by
// encoding/json) into a value for a field of type fieldType.
//
// Conversions driven by the target type are tried first:
//   - time.Time from timestamps
//   - StringSetClaim, utils.StringSet and utils.Set
//   - pointers to non-struct types
//   - structs (or struct pointers) from maps
//   - []string
//
// Otherwise float32/float64 sources are converted to int64, and the result is
// assigned or converted to fieldType when possible. If neither applies, the
// value is returned as-is for the caller to deal with.
//
// A nil source (invalid reflect.Value, e.g. from a JSON null) yields the zero
// value of fieldType. Previously v.Type() would panic on it.
func claimUnmarshalConvert(v reflect.Value, fieldType reflect.Type) (reflect.Value, error) {
	if !v.IsValid() {
		return reflect.Zero(fieldType), nil
	}
	switch {
	// special target types
	case timeType.AssignableTo(fieldType):
		return timestampToTime(v)
	case sSetClaimType.AssignableTo(fieldType):
		return toStringSetClaim(v)
	case sSetType.AssignableTo(fieldType):
		return toStringSet(v)
	case iSetType.AssignableTo(fieldType):
		return toSet(v)
	case fieldType.Kind() == reflect.Ptr && fieldType.Elem().Kind() != reflect.Struct:
		return toAddr(v)
	case v.Type().AssignableTo(mapType) && isStructOrStructPtr(fieldType):
		return mapToStruct(v, fieldType)
	case sSliceType.AssignableTo(fieldType):
		return toStringSlice(v)
	// special source types
	case v.Type().AssignableTo(float32Type), v.Type().AssignableTo(float64Type):
		v = v.Convert(int64Type)
	}

	switch {
	// assignable or convertible
	case v.Type().AssignableTo(fieldType):
		return v, nil
	case v.Type().ConvertibleTo(fieldType):
		return v.Convert(fieldType), nil
	default:
		return v, nil
	}
}
// isStructOrStructPtr reports whether t is a struct type or a pointer to a struct type.
func isStructOrStructPtr(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Struct:
		return true
	case reflect.Ptr:
		return t.Elem().Kind() == reflect.Struct
	default:
		return false
	}
}
// timestampToTime converts a Unix timestamp in seconds into time.Time. The
// timestamp may be any value convertible to int64, e.g. a JSON-decoded float64.
// Values already convertible to time.Time are converted directly; anything else
// is an error.
func timestampToTime(v reflect.Value) (reflect.Value, error) {
	srcType := v.Type()
	if srcType.ConvertibleTo(int64Type) {
		seconds := v.Convert(int64Type).Int()
		return reflect.ValueOf(time.Unix(seconds, 0)), nil
	}
	if srcType.ConvertibleTo(timeType) {
		return v.Convert(timeType), nil
	}
	return reflect.Value{}, fmt.Errorf("type %T cannot be converted to time.Time", v.Interface())
}
// timeToTimestamp converts a value convertible to time.Time into its Unix
// timestamp in seconds (int64). Other types produce an error.
func timeToTimestamp(v reflect.Value) (reflect.Value, error) {
	if !v.Type().ConvertibleTo(timeType) {
		return reflect.Value{}, fmt.Errorf("type %T cannot be converted to timestamp", v.Interface())
	}
	// named t rather than time so the time package is not shadowed
	t := v.Convert(timeType).Interface().(time.Time)
	return reflect.ValueOf(t.Unix()), nil
}
// toStringSet converts a []string or []interface{} value into a utils.StringSet.
func toStringSet(v reflect.Value) (reflect.Value, error) {
	srcType := v.Type()
	if srcType.ConvertibleTo(sSliceType) {
		strs := v.Convert(sSliceType).Interface().([]string)
		return reflect.ValueOf(utils.NewStringSet(strs...)), nil
	}
	if srcType.ConvertibleTo(iSliceType) {
		items := v.Convert(iSliceType).Interface().([]interface{})
		return reflect.ValueOf(utils.NewStringSetFromSet(utils.NewSet(items...))), nil
	}
	return reflect.Value{}, fmt.Errorf("type %T cannot be converted to string set", v.Interface())
}
// toStringSetClaim converts a string, []string or []interface{} value into a
// StringSetClaim. A single string becomes a one-element set.
func toStringSetClaim(v reflect.Value) (reflect.Value, error) {
	srcType := v.Type()
	var set utils.StringSet
	if srcType.ConvertibleTo(stringType) {
		set = utils.NewStringSet(v.Convert(stringType).Interface().(string))
	} else if srcType.ConvertibleTo(sSliceType) {
		set = utils.NewStringSet(v.Convert(sSliceType).Interface().([]string)...)
	} else if srcType.ConvertibleTo(iSliceType) {
		items := v.Convert(iSliceType).Interface().([]interface{})
		set = utils.NewStringSetFromSet(utils.NewSet(items...))
	} else {
		return reflect.Value{}, fmt.Errorf("type %T cannot be converted to string set", v.Interface())
	}
	return reflect.ValueOf(StringSetClaim(set)), nil
}
// toSet converts a []string or []interface{} value into a utils.Set.
func toSet(v reflect.Value) (reflect.Value, error) {
	srcType := v.Type()
	if srcType.ConvertibleTo(sSliceType) {
		strs := v.Convert(sSliceType).Interface().([]string)
		return reflect.ValueOf(utils.NewStringSet(strs...).ToSet()), nil
	}
	if srcType.ConvertibleTo(iSliceType) {
		items := v.Convert(iSliceType).Interface().([]interface{})
		return reflect.ValueOf(utils.NewSet(items...)), nil
	}
	return reflect.Value{}, fmt.Errorf("type %T cannot be converted to set", v.Interface())
}
// toAddr returns a pointer for v, used for claims stored in pointer fields
// (e.g. *bool). An addressable v is returned by address; bool, int, uint and
// float64 kinds get a pointer to a copy via the utils pointer helpers. Any other
// kind is an error.
func toAddr(v reflect.Value) (reflect.Value, error) {
	if !v.IsValid() {
		// the default branch's v.Interface() would panic on an invalid value
		return reflect.Value{}, fmt.Errorf("nil value cannot be addressed")
	}
	if v.CanAddr() {
		return v.Addr(), nil
	}
	switch v.Kind() {
	case reflect.Bool:
		return reflect.ValueOf(utils.BoolPtr(v.Bool())), nil
	case reflect.Int:
		return reflect.ValueOf(utils.IntPtr(int(v.Int()))), nil
	case reflect.Uint:
		return reflect.ValueOf(utils.UIntPtr(uint(v.Uint()))), nil
	case reflect.Float64:
		return reflect.ValueOf(utils.Float64Ptr(v.Float())), nil
	default:
		return reflect.Value{}, fmt.Errorf("value [%v, %T] cannot be addressed", v.Interface(), v.Interface())
	}
}
// toStringSlice converts a value into []string. A []string-convertible value is
// converted directly. A []interface{} value must contain only string-kinded
// elements; nil elements and other kinds are reported as errors.
//
// Element kinds are checked explicitly rather than with ConvertibleTo(string).
// reflect treats integers as convertible to string and would silently produce
// rune strings (e.g. 65 -> "A").
func toStringSlice(v reflect.Value) (reflect.Value, error) {
	switch {
	case v.Type().ConvertibleTo(sSliceType):
		return v.Convert(sSliceType), nil
	case v.Type().ConvertibleTo(iSliceType):
		src := v.Convert(iSliceType)
		dst := reflect.MakeSlice(sSliceType, src.Len(), src.Len())
		for i := 0; i < src.Len(); i++ {
			elem := src.Index(i).Elem()
			if !elem.IsValid() {
				// a nil element unwraps to an invalid Value; elem.Type() would panic
				return reflect.Value{}, fmt.Errorf("type %T cannot be converted to []string, source contains nil element", v.Interface())
			}
			if elem.Kind() != reflect.String {
				return reflect.Value{}, fmt.Errorf("type %T cannot be converted to []string, source contains non-string type %T", v.Interface(), elem.Interface())
			}
			dst.Index(i).Set(elem.Convert(stringType))
		}
		return dst, nil
	default:
		return reflect.Value{}, fmt.Errorf("type %T cannot be converted to []string", v.Interface())
	}
}
// mapToStruct converts a map value into a new value of type ft, which must be a
// struct or a pointer to a struct. The conversion round-trips through JSON: this
// is slower than direct reflection, but honors the target's json tags and
// unmarshalers. The result is a pointer when ft is a pointer type, otherwise the
// struct value. Underlying JSON errors are wrapped with %w.
func mapToStruct(v reflect.Value, ft reflect.Type) (reflect.Value, error) {
	target := ft
	isPtr := ft.Kind() == reflect.Ptr
	if isPtr {
		ft = ft.Elem()
	}
	if ft.Kind() != reflect.Struct {
		// report the requested type; %T on ft.String() would always print "string"
		return reflect.Value{}, fmt.Errorf("map can only convert to struct or pointer of struct. got [%s]", target.String())
	}

	nv := reflect.New(ft)
	data, e := json.Marshal(v.Interface())
	if e != nil {
		return reflect.Value{}, fmt.Errorf("map cannot be serialized to json: %w", e)
	}
	if e := json.Unmarshal(data, nv.Interface()); e != nil {
		return reflect.Value{}, fmt.Errorf("json cannot be converted: %w", e)
	}
	if !isPtr {
		nv = nv.Elem()
	}
	return nv, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"encoding/json"
"fmt"
"reflect"
)
/***************************
Struct Claims Helpers
***************************/
var (
// NOTE: reflect.TypeOf returns the dynamic type of its argument, so claimsType is the
// concrete MapClaims type (identical to mapClaimsType), not the Claims interface type.
claimsType = reflect.TypeOf(Claims(MapClaims{}))
mapClaimsType = reflect.TypeOf(MapClaims{})
mapperType = reflect.TypeOf(FieldClaimsMapper{})
)
// accumulator is folded over embedded Claims by aggregateEmbeddedClaims: it receives the
// value accumulated so far and returns the new value plus whether iteration should continue.
type accumulator func(i interface{}, claims Claims) (accumulated interface{}, shouldContinue bool)
// claimsMapper is implemented by types that can export/import their claims as a map.
type claimsMapper interface {
toMap(owner interface{}, convert bool) (map[string]interface{}, error)
fromMap(owner interface{}, src map[string]interface{}) error
}
// FieldClaimsMapper is a helper type that can be embedded into struct based claims.
// *FieldClaimsMapper implements claimsMapper (its methods have pointer receivers).
// See BasicClaims as an example.
// Note: having non-claims struct as fields is not recommended for deserialization
type FieldClaimsMapper struct {
fields map[string][]int // Index of fields holding claim. Includes embedded structs
interfaces [][]int // Index of directly embedded Claims interfaces
}
// Get returns the value of claim for owner. A non-zero struct field tagged
// with the claim takes precedence. Otherwise the first embedded Claims that Has
// the claim supplies it. It returns nil when nothing matches.
func (m *FieldClaimsMapper) Get(owner interface{}, claim string) interface{} {
	if fv := m.findFieldValue(owner, claim); fv.IsValid() && !fv.IsZero() {
		return fv.Interface()
	}
	return m.aggregateEmbeddedClaims(owner, nil, func(_ interface{}, embedded Claims) (interface{}, bool) {
		if !embedded.Has(claim) {
			return nil, true
		}
		// found: stop iterating
		return embedded.Get(claim), false
	})
}
// Has reports whether owner has a non-zero tagged field for claim, or whether
// any embedded Claims reports having it.
func (m *FieldClaimsMapper) Has(owner interface{}, claim string) bool {
	if fv := m.findFieldValue(owner, claim); fv.IsValid() && !fv.IsZero() {
		return true
	}
	found := m.aggregateEmbeddedClaims(owner, false, func(_ interface{}, embedded Claims) (interface{}, bool) {
		has := embedded.Has(claim)
		// stop as soon as one embedded Claims has it
		return has, !has
	})
	return found.(bool)
}
// Set assigns value to claim on owner.
// If a struct field is tagged with the claim, it is set via m.set, which panics when the
// value cannot be assigned or converted to the field type.
// The value is then also passed to Set on every non-nil embedded Claims, regardless of
// whether a struct field already took it.
// NOTE(review): presumably intentional so embedded claims stay in sync — confirm.
func (m *FieldClaimsMapper) Set(owner interface{}, claim string, value interface{}) {
v := m.findFieldValue(owner, claim)
if v.IsValid() {
if e := m.set(v, value); e != nil {
panic(e)
}
}
// try with all embedded Claims interface
m.aggregateEmbeddedClaims(owner, nil, func(i interface{}, claims Claims) (interface{}, bool) {
claims.Set(claim, value)
return nil, true
})
}
// return claims values as a map, without any conversion
func (m *FieldClaimsMapper) Values(owner interface{}) map[string]interface{} {
values, e := m.toMap(owner, false)
if e != nil {
return map[string]interface{}{}
}
return values
}
func (m *FieldClaimsMapper) DoMarshalJSON(owner interface{}) ([]byte, error) {
v, e := m.toMap(owner, true)
if e != nil {
return nil, e
}
return json.Marshal(v)
}
func (m *FieldClaimsMapper) DoUnmarshalJSON(owner interface{}, bytes []byte) error {
values := map[string]interface{}{}
if e := json.Unmarshal(bytes, &values); e != nil {
return e
}
if e := m.fromMap(owner, values); e != nil {
return e
}
return nil
}
// findFieldValue returns owner's struct field tagged with claim, or an invalid
// reflect.Value when no field maps to that claim. owner may be a struct or a
// pointer to one.
func (m *FieldClaimsMapper) findFieldValue(owner interface{}, claim string) reflect.Value {
	m.prepare(owner)
	index, ok := m.fields[claim]
	if !ok {
		return reflect.Value{}
	}
	return reflect.Indirect(reflect.ValueOf(owner)).FieldByIndex(index)
}
// set assigns setTo to the struct field value fv.
//   - A nil or zero setTo resets the field to its zero value.
//   - Otherwise the value is assigned directly, or converted when assignment is
//     not possible.
//   - Any other type mismatch is an error.
//
// Interface-typed fields are set directly. The previous fv.Elem() unwrap
// yielded a non-settable copy (or an invalid value for a nil interface), so
// such fields could never be set.
func (m *FieldClaimsMapper) set(fv reflect.Value, setTo interface{}) error {
	if !fv.IsValid() {
		return fmt.Errorf("field is invalid and cannot be set")
	}
	if !fv.CanSet() {
		// this shouldn't happen because struct field values are settable in general
		return fmt.Errorf("field [%v] is not settable", fv.Type())
	}
	ft := fv.Type()
	v := reflect.ValueOf(setTo)
	switch {
	case !v.IsValid() || v.IsZero():
		// nil setTo gives an invalid Value, whose Type() would panic
		fv.Set(reflect.Zero(ft))
	case v.Type().AssignableTo(ft):
		fv.Set(v)
	case v.Type().ConvertibleTo(ft):
		fv.Set(v.Convert(ft))
	default:
		return fmt.Errorf("value with type [%v] cannot be set to field [%v]", v.Type(), ft)
	}
	return nil
}
// toMap collects owner's claims into a new map. Values from embedded Claims
// (claimsMapper implementations and MapClaims) are gathered first. Non-zero
// tagged struct fields are then added, overriding embedded values with the same
// key. When convert is true, values go through claimMarshalConvert (used for
// JSON serialization).
func (m *FieldClaimsMapper) toMap(owner interface{}, convert bool) (map[string]interface{}, error) {
	m.prepare(owner)
	var err error
	aggregated := m.aggregateEmbeddedClaims(owner, map[string]interface{}{}, func(i interface{}, claims Claims) (interface{}, bool) {
		acc := i.(map[string]interface{})
		var values map[string]interface{}
		if sc, ok := claims.(claimsMapper); ok {
			values, err = sc.toMap(sc, convert)
		} else if mc, ok := claims.(MapClaims); ok {
			values, err = mc.toMap(convert)
		}
		if err != nil {
			// return the accumulator rather than nil: a nil result would make the
			// map type assertion below panic before err gets checked
			return acc, false
		}
		for k, v := range values {
			acc[k] = v
		}
		return acc, true
	})
	if err != nil {
		return nil, err
	}
	ret, ok := aggregated.(map[string]interface{})
	if !ok {
		ret = map[string]interface{}{}
	}

	// collect claims from known fields
	ov := reflect.ValueOf(owner)
	if ov.Kind() == reflect.Ptr {
		ov = ov.Elem()
	}
	for k, index := range m.fields {
		fv := ov.FieldByIndex(index)
		if !fv.IsValid() || fv.IsZero() {
			continue
		}
		if !convert {
			ret[k] = fv.Interface()
			continue
		}
		v, e := claimMarshalConvert(fv)
		if e != nil {
			return nil, e
		}
		ret[k] = v.Interface()
	}
	return ret, nil
}
func (m *FieldClaimsMapper) fromMap(owner interface{}, src map[string]interface{}) error {
m.prepare(owner)
ov := reflect.ValueOf(owner)
if ov.Kind() == reflect.Ptr {
ov = ov.Elem()
}
for k, index := range m.fields {
value, ok := src[k]
if !ok {
continue
}
fv := ov.FieldByIndex(index)
if fv.IsValid() && fv.CanSet() {
// some types requires special conversion
v, e := claimUnmarshalConvert(reflect.ValueOf(value), fv.Type())
if e != nil {
return e
}
fv.Set(v)
}
}
// try set internal Claims interfaces
err,_ := m.aggregateEmbeddedClaims(owner, nil, func(i interface{}, claims Claims) (interface{}, bool) {
if sc, ok := claims.(claimsMapper); ok {
e := sc.fromMap(sc, src)
if e != nil {
return e, false
}
} else if mc, ok := claims.(MapClaims); ok {
e := mc.fromMap(src)
if e != nil {
return e, false
}
}
return nil, true
}).(error)
return err
}
func (m *FieldClaimsMapper) aggregateEmbeddedClaims(owner interface{}, initial interface{}, accumulator accumulator) interface{} {
m.prepare(owner)
if len(m.interfaces) == 0 {
return initial
}
v := reflect.ValueOf(owner)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
i := initial
next := true
for _,index := range m.interfaces {
fv := v.FieldByIndex(index)
if !fv.IsValid() || fv.IsZero() {
continue
}
if claims, ok := fv.Interface().(Claims); ok {
i, next = accumulator(i, claims)
if !next {
break
}
}
}
return i
}
// prepare lazily builds the claim-to-field index map and the list of embedded
// Claims fields for owner's type. It is a no-op once m.fields is set; there is
// no locking, so the first use is not synchronized across goroutines.
func (m *FieldClaimsMapper) prepare(owner interface{}) {
	if m.fields != nil {
		return
	}
	ownerType := reflect.TypeOf(owner)
	m.fields = make(map[string][]int)
	m.populateFieldMap(ownerType, nil)
	m.interfaces = make([][]int, 0)
	m.populateInterfaceList(ownerType)
}
// populateFieldMap recursively maps each `claim`-tagged field of structType to
// its field index path, descending into embedded (anonymous) structs. The
// FieldClaimsMapper type itself is skipped.
//
// Every path is built in a freshly allocated slice. Appending to the shared
// prefix could reuse its spare capacity: with three or more levels of
// embedding, sibling fields would overwrite each other's stored paths.
func (m *FieldClaimsMapper) populateFieldMap(structType reflect.Type, index []int) {
	t := structType
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct || mapperType.AssignableTo(t) {
		return
	}
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		path := make([]int, len(index), len(index)+len(field.Index))
		copy(path, index)
		path = append(path, field.Index...)
		if field.Anonymous {
			m.populateFieldMap(field.Type, path)
			continue
		}
		if claim, ok := field.Tag.Lookup(ClaimTag); ok {
			m.fields[claim] = path
		}
	}
}
// populateInterfaceList records the direct (non-recursive) fields of structType
// that may hold embedded claims. These are interface-typed fields that a
// MapClaims value can be assigned to, and fields assignable to MapClaims.
func (m *FieldClaimsMapper) populateInterfaceList(structType reflect.Type) {
	t := structType
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct || mapperType.AssignableTo(t) {
		return
	}
	for i, n := 0, t.NumField(); i < n; i++ {
		sf := t.Field(i)
		holdsClaimsInterface := sf.Type.Kind() == reflect.Interface && claimsType.AssignableTo(sf.Type)
		isMapClaims := sf.Type.AssignableTo(mapClaimsType)
		if holdsClaimsInterface || isMapClaims {
			m.interfaces = append(m.interfaces, sf.Index)
		}
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"encoding/json"
"github.com/cisco-open/go-lanai/pkg/utils"
)
/******************************
OAuth2Request
******************************/
// excludedParameters lists request parameters that NewOAuth2Request strips from
// Parameters (ParameterPassword and ParameterClientSecret).
var excludedParameters = utils.NewStringSet(ParameterPassword, ParameterClientSecret)
// OAuth2Request describes an OAuth2 authorization/token request.
// NewOAuth2Request derives a new request from this one, with additional options applied on top.
//goland:noinspection GoNameStartsWithPackageName
type OAuth2Request interface {
Parameters() map[string]string
ClientId() string
Scopes() utils.StringSet
Approved() bool
GrantType() string
RedirectUri() string
ResponseTypes() utils.StringSet
Extensions() map[string]interface{}
NewOAuth2Request(...RequestOptionsFunc) OAuth2Request
}
/******************************
Implementation
******************************/
// RequestDetails holds the data of an OAuth2 request; it is also the JSON form of oauth2Request.
type RequestDetails struct {
Parameters map[string]string `json:"parameters"`
ClientId string `json:"clientId"`
Scopes utils.StringSet `json:"scope"`
Approved bool `json:"approved"`
GrantType string `json:"grantType"`
RedirectUri string `json:"redirectUri"`
ResponseTypes utils.StringSet `json:"responseTypes"`
Extensions map[string]interface{} `json:"extensions"`
}
// RequestOptionsFunc customizes RequestDetails while a request is being built.
type RequestOptionsFunc func(opt *RequestDetails)
// oauth2Request is the default OAuth2Request implementation.
// Its accessor methods shadow the embedded fields of the same name, hence the explicit
// r.RequestDetails.X selectors in the methods.
type oauth2Request struct {
RequestDetails
}
// NewOAuth2Request builds an OAuth2Request by applying optFuncs in order.
// Parameters, Scopes, ResponseTypes and Extensions start as empty, non-nil
// collections. After all options run, every parameter named in
// excludedParameters is removed from Parameters.
func NewOAuth2Request(optFuncs ...RequestOptionsFunc) OAuth2Request {
	req := &oauth2Request{}
	req.RequestDetails = RequestDetails{
		Parameters:    map[string]string{},
		Scopes:        utils.NewStringSet(),
		ResponseTypes: utils.NewStringSet(),
		Extensions:    map[string]interface{}{},
	}
	details := &req.RequestDetails
	for _, apply := range optFuncs {
		apply(details)
	}
	for name := range excludedParameters {
		delete(details.Parameters, name)
	}
	return req
}
// Parameters returns the request parameters (excluded sensitive parameters already removed).
func (r *oauth2Request) Parameters() map[string]string {
return r.RequestDetails.Parameters
}
// ClientId returns the client ID of the request.
func (r *oauth2Request) ClientId() string {
return r.RequestDetails.ClientId
}
// Scopes returns the requested scopes.
func (r *oauth2Request) Scopes() utils.StringSet {
return r.RequestDetails.Scopes
}
// Approved reports whether the request is marked approved.
func (r *oauth2Request) Approved() bool {
return r.RequestDetails.Approved
}
// GrantType returns the grant type of the request.
func (r *oauth2Request) GrantType() string {
return r.RequestDetails.GrantType
}
// RedirectUri returns the redirect URI of the request.
func (r *oauth2Request) RedirectUri() string {
return r.RequestDetails.RedirectUri
}
// ResponseTypes returns the requested response types.
func (r *oauth2Request) ResponseTypes() utils.StringSet {
return r.RequestDetails.ResponseTypes
}
// Extensions returns the extension values attached to the request.
func (r *oauth2Request) Extensions() map[string]interface{} {
return r.RequestDetails.Extensions
}
// NewOAuth2Request returns a new request that starts as a copy of r (see
// copyFunc) and then applies additional in order.
func (r *oauth2Request) NewOAuth2Request(additional ...RequestOptionsFunc) OAuth2Request {
	opts := make([]RequestOptionsFunc, 0, len(additional)+1)
	opts = append(opts, r.copyFunc())
	opts = append(opts, additional...)
	return NewOAuth2Request(opts...)
}
// copyFunc returns an option that copies r's details into the RequestDetails being
// built. Scopes and ResponseTypes are copied into new sets, so the derived request
// can modify them without affecting r. Parameters and Extensions entries are copied
// into the target's (pre-initialized) maps; the entry values themselves are shared.
func (r *oauth2Request) copyFunc() RequestOptionsFunc {
	return func(opt *RequestDetails) {
		src := &r.RequestDetails
		opt.ClientId = src.ClientId
		opt.Scopes = src.Scopes.Copy()
		opt.Approved = src.Approved
		opt.GrantType = src.GrantType
		opt.RedirectUri = src.RedirectUri
		// copied like Scopes; sharing the set let derived requests mutate r
		opt.ResponseTypes = src.ResponseTypes.Copy()
		for k, v := range src.Parameters {
			opt.Parameters[k] = v
		}
		for k, v := range src.Extensions {
			opt.Extensions[k] = v
		}
	}
}
// MarshalJSON implements json.Marshaler by encoding the embedded RequestDetails.
func (r *oauth2Request) MarshalJSON() ([]byte, error) {
	details := r.RequestDetails
	return json.Marshal(details)
}

// UnmarshalJSON implements json.Unmarshaler by decoding into the embedded RequestDetails.
func (r *oauth2Request) UnmarshalJSON(data []byte) error {
	return json.Unmarshal(data, &r.RequestDetails)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"github.com/cisco-open/go-lanai/pkg/utils"
"strings"
"time"
)
/*****************************
Abstractions
*****************************/
// TokenType identifies the kind of OAuth2 token, e.g. "bearer".
type TokenType string

// Known token type values. They are untyped string constants, so they compare directly against
// plain strings (as HttpHeader does) and can also be assigned to TokenType.
const (
	TokenTypeBearer = "bearer"
	TokenTypeMac    = "mac"
	TokenTypeBasic  = "basic"
)

// HttpHeader returns the capitalized scheme name for this token type ("MAC", "Basic" or "Bearer").
// Matching is case-insensitive, and any unrecognized type falls back to "Bearer".
func (t TokenType) HttpHeader() string {
	lowered := strings.ToLower(string(t))
	if lowered == TokenTypeMac {
		return "MAC"
	}
	if lowered == TokenTypeBasic {
		return "Basic"
	}
	return "Bearer"
}
// Token is the read-only view shared by access tokens and refresh tokens.
type Token interface {
// Value returns the token string.
Value() string
// ExpiryTime returns when the token expires. The implementations in this package use the zero time for "no expiry".
ExpiryTime() time.Time
// Expired reports whether the token is past its expiry time.
Expired() bool
// Details returns additional, implementation-specific token data.
Details() map[string]interface{}
}
// ClaimsContainer is implemented by tokens that carry Claims, e.g. DefaultAccessToken and DefaultRefreshToken.
type ClaimsContainer interface {
// Claims returns the attached claims.
Claims() Claims
// SetClaims replaces the attached claims.
SetClaims(claims Claims)
}
// AccessToken is an OAuth2 access token.
type AccessToken interface {
Token
// Type returns the token type, e.g. TokenTypeBearer.
Type() TokenType
// IssueTime returns when the token was issued.
IssueTime() time.Time
// Scopes returns the scopes granted to the token.
Scopes() utils.StringSet
// RefreshToken returns the associated refresh token. DefaultAccessToken returns nil when there is none.
RefreshToken() RefreshToken
}
// RefreshToken is an OAuth2 refresh token.
type RefreshToken interface {
Token
// WillExpire reports whether the token has an expiry time at all.
// DefaultRefreshToken returns false when its expiry time is the zero value.
WillExpire() bool
}
/*******************************
Common Impl. AccessToken
*******************************/
// DefaultAccessToken implements AccessToken and ClaimsContainer
// Prefer NewDefaultAccessToken or FromAccessToken to create instances; they initialize the scopes set,
// the details map and (NewDefaultAccessToken) the claims.
type DefaultAccessToken struct {
// claims attached to the token; NewDefaultAccessToken starts with an empty MapClaims
claims Claims
// token type, TokenTypeBearer by default (see NewDefaultAccessToken)
tokenType TokenType
// the token string
value string
// zero value means the token never expires (see Expired)
expiryTime time.Time
issueTime time.Time
scopes utils.StringSet
// nil when the token has no refresh token
refreshToken *DefaultRefreshToken
// additional fields; MarshalJSON writes them next to the standard fields
details map[string]interface{}
}
// NewDefaultAccessToken creates a bearer token with the given value, issued now (local time, not
// converted to UTC), with empty scopes, details and MapClaims, and no expiry time.
func NewDefaultAccessToken(value string) *DefaultAccessToken {
	token := &DefaultAccessToken{
		value:     value,
		tokenType: TokenTypeBearer,
		issueTime: time.Now(),
	}
	token.scopes = utils.NewStringSet()
	token.details = map[string]interface{}{}
	token.claims = MapClaims{}
	return token
}
// FromAccessToken returns a *DefaultAccessToken copy of the given AccessToken.
//
// Scopes and details are copied into new containers, while claims are shared with the source.
// For a *DefaultAccessToken source, the refresh token pointer is shared too. For any other
// implementation, the refresh token goes through SetRefreshToken. Claims are carried over when the
// source implements ClaimsContainer, the same way they are for a *DefaultAccessToken source.
func FromAccessToken(token AccessToken) *DefaultAccessToken {
	if t, ok := token.(*DefaultAccessToken); ok {
		return &DefaultAccessToken{
			value:        t.value,
			tokenType:    t.tokenType,
			expiryTime:   t.expiryTime,
			issueTime:    t.issueTime,
			scopes:       t.scopes.Copy(),
			claims:       t.claims,
			details:      copyMap(t.details),
			refreshToken: t.refreshToken,
		}
	}
	cp := &DefaultAccessToken{
		value:      token.Value(),
		tokenType:  token.Type(),
		expiryTime: token.ExpiryTime(),
		issueTime:  token.IssueTime(),
		scopes:     token.Scopes().Copy(),
		details:    copyMap(token.Details()),
	}
	// previously claims were silently dropped for non-default implementations
	if container, ok := token.(ClaimsContainer); ok {
		cp.claims = container.Claims()
	}
	cp.SetRefreshToken(token.RefreshToken())
	return cp
}
// Value implements AccessToken by returning the raw token string.
func (t *DefaultAccessToken) Value() string {
	return t.value
}
// Details implements AccessToken. The internal map is returned without copying.
func (t *DefaultAccessToken) Details() map[string]interface{} {
	return t.details
}
// Type implements AccessToken by returning the configured token type.
func (t *DefaultAccessToken) Type() TokenType {
	return t.tokenType
}
// IssueTime implements AccessToken by returning when the token was issued.
func (t *DefaultAccessToken) IssueTime() time.Time {
	return t.issueTime
}
// ExpiryTime implements AccessToken. The zero time means the token has no expiry.
func (t *DefaultAccessToken) ExpiryTime() time.Time {
	return t.expiryTime
}
// Expired implements AccessToken. A token without expiry time (zero value) never expires.
func (t *DefaultAccessToken) Expired() bool {
	if t.expiryTime.IsZero() {
		return false
	}
	return time.Now().After(t.expiryTime)
}
// Scopes implements AccessToken. The internal set is returned without copying.
func (t *DefaultAccessToken) Scopes() utils.StringSet {
	return t.scopes
}
// RefreshToken implements AccessToken.
// When no refresh token is set, it returns an untyped nil so callers can compare the result with nil.
// A nil *DefaultRefreshToken stored in the interface would not compare equal to nil.
func (t *DefaultAccessToken) RefreshToken() RefreshToken {
	if refresh := t.refreshToken; refresh != nil {
		return refresh
	}
	return nil
}
// Claims implements ClaimsContainer by returning the attached claims.
func (t *DefaultAccessToken) Claims() Claims {
	return t.claims
}
// SetClaims implements ClaimsContainer by replacing the attached claims (no copy is made).
func (t *DefaultAccessToken) SetClaims(claims Claims) {
	t.claims = claims
}
/* Setters */
// SetValue replaces the token string and returns t for chaining.
func (t *DefaultAccessToken) SetValue(v string) *DefaultAccessToken {
	t.value = v
	return t
}
// SetIssueTime stores v converted to UTC and returns t for chaining.
func (t *DefaultAccessToken) SetIssueTime(v time.Time) *DefaultAccessToken {
	utc := v.UTC()
	t.issueTime = utc
	return t
}
// SetExpireTime stores v converted to UTC as the expiry time and returns t for chaining.
// Passing the zero time makes the token non-expiring.
func (t *DefaultAccessToken) SetExpireTime(v time.Time) *DefaultAccessToken {
	utc := v.UTC()
	t.expiryTime = utc
	return t
}
// SetRefreshToken sets the refresh token and returns t for chaining.
// A *DefaultRefreshToken is stored as-is (shared), and a nil interface clears the refresh token.
// Any other implementation is converted with FromRefreshToken.
func (t *DefaultAccessToken) SetRefreshToken(v RefreshToken) *DefaultAccessToken {
	switch refresh := v.(type) {
	case *DefaultRefreshToken:
		t.refreshToken = refresh
	case nil:
		t.refreshToken = nil
	default:
		t.refreshToken = FromRefreshToken(v)
	}
	return t
}
// SetScopes replaces the scopes with a copy of the given set and returns t for chaining.
func (t *DefaultAccessToken) SetScopes(scopes utils.StringSet) *DefaultAccessToken {
	copied := scopes.Copy()
	t.scopes = copied
	return t
}
// AddScopes adds the given scopes and returns t for chaining.
// A missing scope set (e.g. on a zero-value DefaultAccessToken) is created first, so Add never writes
// into a nil set.
func (t *DefaultAccessToken) AddScopes(scopes ...string) *DefaultAccessToken {
	if t.scopes == nil {
		t.scopes = utils.NewStringSet()
	}
	t.scopes.Add(scopes...)
	return t
}
// RemoveScopes removes the given scopes and returns t for chaining.
func (t *DefaultAccessToken) RemoveScopes(scopes ...string) *DefaultAccessToken {
	current := t.scopes
	current.Remove(scopes...)
	return t
}
// PutDetails sets details[key] = value and returns t for chaining. A nil value removes the key.
// The details map is created on first use, so this also works on a zero-value DefaultAccessToken
// instead of panicking on a nil map.
func (t *DefaultAccessToken) PutDetails(key string, value interface{}) *DefaultAccessToken {
	if value == nil {
		// delete on a nil map is a no-op, so no initialization is needed here
		delete(t.details, key)
		return t
	}
	if t.details == nil {
		t.details = map[string]interface{}{}
	}
	t.details[key] = value
	return t
}
/********************************
Common Impl. RefreshToken
********************************/
// DefaultRefreshToken implements RefreshToken and ClaimsContainer
// Prefer NewDefaultRefreshToken or FromRefreshToken to create instances; they initialize the details map.
type DefaultRefreshToken struct {
// claims attached to the token; NewDefaultRefreshToken starts with an empty MapClaims
claims Claims
// the token string; the only field serialized by MarshalJSON
value string
// zero value means the token does not expire (see Expired and WillExpire)
expiryTime time.Time
// additional data exposed through Details
details map[string]interface{}
}
// NewDefaultRefreshToken creates a refresh token with the given value, empty details and MapClaims,
// and no expiry time.
func NewDefaultRefreshToken(value string) *DefaultRefreshToken {
	token := &DefaultRefreshToken{value: value}
	token.details = map[string]interface{}{}
	token.claims = MapClaims{}
	return token
}
// FromRefreshToken returns a *DefaultRefreshToken copy of the given RefreshToken.
//
// The value and expiry time are copied, details are copied into a new map, and claims are shared with
// the source. Claims are also taken from non-default implementations that implement ClaimsContainer.
// Previously the expiry time was dropped, which turned an expiring refresh token into a
// non-expiring one (WillExpire/Expired both false).
func FromRefreshToken(token RefreshToken) *DefaultRefreshToken {
	if t, ok := token.(*DefaultRefreshToken); ok {
		return &DefaultRefreshToken{
			value:      t.value,
			expiryTime: t.expiryTime,
			details:    copyMap(t.details),
			claims:     t.claims,
		}
	}
	cp := &DefaultRefreshToken{
		value:      token.Value(),
		expiryTime: token.ExpiryTime(),
		details:    copyMap(token.Details()),
	}
	if container, ok := token.(ClaimsContainer); ok {
		cp.claims = container.Claims()
	}
	return cp
}
// Value implements RefreshToken by returning the raw token string.
func (t *DefaultRefreshToken) Value() string {
	return t.value
}
// Details implements RefreshToken. The internal map is returned without copying.
func (t *DefaultRefreshToken) Details() map[string]interface{} {
	return t.details
}
// ExpiryTime implements RefreshToken. The zero time means the token has no expiry.
func (t *DefaultRefreshToken) ExpiryTime() time.Time {
	return t.expiryTime
}
// Expired implements RefreshToken. A token without expiry time (zero value) never expires.
func (t *DefaultRefreshToken) Expired() bool {
	if t.expiryTime.IsZero() {
		return false
	}
	return time.Now().After(t.expiryTime)
}
// WillExpire implements RefreshToken: it reports whether an expiry time has been set at all.
func (t *DefaultRefreshToken) WillExpire() bool {
	hasExpiry := !t.expiryTime.IsZero()
	return hasExpiry
}
// Claims implements ClaimsContainer by returning the attached claims.
func (t *DefaultRefreshToken) Claims() Claims {
	return t.claims
}
// SetClaims implements ClaimsContainer by replacing the attached claims (no copy is made).
func (t *DefaultRefreshToken) SetClaims(claims Claims) {
	t.claims = claims
}
/* Setters */
// SetValue replaces the token string and returns t for chaining.
func (t *DefaultRefreshToken) SetValue(v string) *DefaultRefreshToken {
	t.value = v
	return t
}
// SetExpireTime stores v converted to UTC as the expiry time and returns t for chaining.
// Passing the zero time makes the token non-expiring.
func (t *DefaultRefreshToken) SetExpireTime(v time.Time) *DefaultRefreshToken {
	utc := v.UTC()
	t.expiryTime = utc
	return t
}
// PutDetails sets details[key] = value and returns t for chaining. A nil value removes the key.
// The details map is created on first use, so this also works on a zero-value DefaultRefreshToken
// instead of panicking on a nil map.
func (t *DefaultRefreshToken) PutDetails(key string, value interface{}) *DefaultRefreshToken {
	if value == nil {
		// delete on a nil map is a no-op, so no initialization is needed here
		delete(t.details, key)
		return t
	}
	if t.details == nil {
		t.details = map[string]interface{}{}
	}
	t.details[key] = value
	return t
}
/********************************
Helpers
********************************/
// copyMap returns a shallow copy of src. The result is never nil, even when src is nil, so callers
// can write into it directly.
func copyMap(src map[string]interface{}) map[string]interface{} {
	dst := make(map[string]interface{}, len(src))
	for key, val := range src {
		dst[key] = val
	}
	return dst
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"reflect"
"strings"
"time"
)
// Error messages returned by the JSON value converters below when a field has an unexpected type or format.
const (
errTmplFieldExpectString = `invalid field type. expected string`
errTmplFieldExpectInt = `invalid field type. expected integer`
errTmplFieldExpectISO8601 = `invalid field format. expected ISO8601 formatted string`
)
// valueConverterFunc converts a raw value decoded from JSON into the reflect.Value that extractField
// assigns to the destination variable.
type valueConverterFunc func(v interface{}) (reflect.Value, error)
/************************
DefaultAccessToken
************************/
// accessTokenIgnoredDetails lists the JSON fields that DefaultAccessToken maps to dedicated struct fields.
// UnmarshalJSON does not copy them into details.
var accessTokenIgnoredDetails = utils.NewStringSet(
JsonFieldAccessTokenValue, JsonFieldTokenType, JsonFieldScope,
JsonFieldExpiryTime, JsonFieldIssueTime, JsonFieldExpiresIn, JsonFieldRefreshTokenValue)
// scopeSeparator separates individual scopes in the serialized JsonFieldScope value.
var scopeSeparator = " "
// MarshalJSON implements json.Marshaler.
//
// Details are written first, and the standard fields then overwrite any detail that uses the same key.
// Issue time and expiry fields are written only when set. A zero time would be serialized as
// "0001-01-01T00:00:00Z", which expiryToTime (used by UnmarshalJSON) rejects, so emitting it would
// break the JSON round trip. JsonFieldExpiresIn is the number of whole seconds left until expiry
// (negative once expired).
func (t *DefaultAccessToken) MarshalJSON() ([]byte, error) {
	data := make(map[string]interface{}, len(t.details)+7)
	for k, v := range t.details {
		data[k] = v
	}
	data[JsonFieldAccessTokenValue] = t.value
	data[JsonFieldTokenType] = t.tokenType
	data[JsonFieldScope] = strings.Join(t.scopes.Values(), scopeSeparator)
	if !t.issueTime.IsZero() {
		data[JsonFieldIssueTime] = t.issueTime.Format(utils.ISO8601Seconds)
	}
	if !t.expiryTime.IsZero() {
		data[JsonFieldExpiryTime] = t.expiryTime.Format(utils.ISO8601Seconds)
		data[JsonFieldExpiresIn] = int(time.Until(t.expiryTime).Seconds())
	}
	if t.refreshToken != nil {
		// DefaultRefreshToken.MarshalJSON serializes only its value
		data[JsonFieldRefreshTokenValue] = t.refreshToken
	}
	return json.Marshal(data)
}
// UnmarshalJSON implements json.Unmarshaler.
//
// Required fields: JsonFieldAccessTokenValue and JsonFieldTokenType.
// Optional fields:
//   - JsonFieldScope: scopeSeparator-delimited scopes. If absent, the token gets an empty scope set
//     (RFC 6749 §5.1 allows omitting "scope" when it equals the requested scope).
//   - JsonFieldIssueTime: ISO8601 time.
//   - JsonFieldExpiryTime / JsonFieldExpiresIn: the absolute expiry time is preferred, and the relative
//     JsonFieldExpiresIn (seconds after the issue time) is the fallback. If neither is present, the token
//     has no expiry time. That is also what MarshalJSON produces for a token with zero expiry time, and
//     RFC 6749 §5.1 only RECOMMENDS expires_in.
//   - JsonFieldRefreshTokenValue: the refresh token value.
//
// All remaining fields are copied into the token's details.
func (t *DefaultAccessToken) UnmarshalJSON(data []byte) error {
	parsed := map[string]interface{}{}
	if err := json.Unmarshal(data, &parsed); err != nil {
		return err
	}
	if err := extractField(parsed, JsonFieldAccessTokenValue, true, &t.value, anyToString); err != nil {
		return err
	}
	if err := extractField(parsed, JsonFieldTokenType, true, &t.tokenType, stringToTokenType); err != nil {
		return err
	}
	if err := extractField(parsed, JsonFieldScope, false, &t.scopes, stringSliceToStringSet); err != nil {
		return err
	}
	if t.scopes == nil {
		t.scopes = utils.NewStringSet()
	}
	// issue time is optional
	if err := extractField(parsed, JsonFieldIssueTime, false, &t.issueTime, expiryToTime); err != nil {
		return err
	}

	// expiry: prefer JsonFieldExpiryTime; fall back to JsonFieldExpiresIn when the former is missing or malformed
	_, hasExpiry := parsed[JsonFieldExpiryTime]
	expiryErr := extractField(parsed, JsonFieldExpiryTime, false, &t.expiryTime, expiryToTime)
	if !hasExpiry || expiryErr != nil {
		if _, hasExpiresIn := parsed[JsonFieldExpiresIn]; hasExpiresIn {
			if err := extractField(parsed, JsonFieldExpiresIn, true, &t.expiryTime, expireInToTimeConverter(t.issueTime)); err != nil {
				return err
			}
		} else if expiryErr != nil {
			return expiryErr
		}
	}

	if err := extractField(parsed, JsonFieldRefreshTokenValue, false, &t.refreshToken, stringToRefreshToken); err != nil {
		return err
	}

	// put the rest of fields to details; a zero-value token has no details map yet
	if t.details == nil {
		t.details = map[string]interface{}{}
	}
	for k, v := range parsed {
		if !accessTokenIgnoredDetails.Has(k) {
			t.details[k] = v
		}
	}
	return nil
}
/************************
DefaultRefreshToken
************************/
// MarshalJSON implements json.Marshaler. The token is serialized as a bare JSON string holding
// its value; expiry time, details and claims are not included.
func (t *DefaultRefreshToken) MarshalJSON() ([]byte, error) {
	value := t.value
	return json.Marshal(value)
}
// UnmarshalJSON implements json.Unmarshaler. data must be a JSON string, which becomes the token value.
// JSON null leaves the current value unchanged.
func (t *DefaultRefreshToken) UnmarshalJSON(data []byte) error {
	value := t.value
	if err := json.Unmarshal(data, &value); err != nil {
		return err
	}
	t.value = value
	return nil
}
/************************
Helpers
************************/
// extractField looks up field in data, converts it with converter and stores the result in the
// variable that destPtr points to.
//
// A missing field is an error only when required is true; otherwise it is silently skipped.
// Conversion errors are wrapped with %w, so callers can inspect them with errors.Is/As.
// destPtr must be a non-nil pointer whose element type accepts the converted value. Otherwise an
// error is returned instead of letting reflect panic.
func extractField(data map[string]interface{}, field string, required bool, destPtr interface{}, converter valueConverterFunc) error {
	v, ok := data[field]
	switch {
	case !ok && required:
		return fmt.Errorf("cannot find required field [%s]", field)
	case !ok:
		return nil
	}
	value, err := converter(v)
	if err != nil {
		return fmt.Errorf("cannot parse field [%s]: %w", field, err)
	}
	dest := reflect.ValueOf(destPtr)
	if dest.Kind() != reflect.Ptr || dest.IsNil() {
		return fmt.Errorf("cannot set field [%s]: destination must be a non-nil pointer, got %T", field, destPtr)
	}
	elem := dest.Elem()
	if !value.IsValid() || !value.Type().AssignableTo(elem.Type()) {
		return fmt.Errorf("cannot set field [%s]: converted value is not assignable to %s", field, elem.Type())
	}
	elem.Set(value)
	return nil
}
// anyToString accepts only string values and returns them unchanged as a reflect.Value.
func anyToString(v interface{}) (reflect.Value, error) {
	if s, ok := v.(string); ok {
		return reflect.ValueOf(s), nil
	}
	return reflect.Value{}, fmt.Errorf(errTmplFieldExpectString)
}
func stringToTokenType(v interface{}) (reflect.Value, error) {
s, ok := v.(string)
if !ok {
return reflect.Value{}, fmt.Errorf(errTmplFieldExpectString)
}
return reflect.ValueOf(TokenType(s)), nil
}
// stringSliceToStringSet converts a scopeSeparator-delimited string (e.g. "read write") into a
// utils.StringSet. Empty segments are skipped, so "" produces an empty set rather than a set
// containing "". Such segments come from an empty input, consecutive separators, or a
// leading/trailing separator. MarshalJSON writes "" for a token without scopes.
func stringSliceToStringSet(v interface{}) (reflect.Value, error) {
	joined, ok := v.(string)
	if !ok {
		return reflect.Value{}, fmt.Errorf(errTmplFieldExpectString)
	}
	scopes := utils.NewStringSet()
	for _, s := range strings.Split(joined, scopeSeparator) {
		if s != "" {
			scopes.Add(s)
		}
	}
	return reflect.ValueOf(scopes), nil
}
// expiryToTime parses an ISO8601 string, trying utils.ParseTimeISO8601 first and then the
// utils.ISO8601Milliseconds layout. A result equal to the zero time counts as a parse failure.
func expiryToTime(v interface{}) (reflect.Value, error) {
	str, ok := v.(string)
	if !ok {
		return reflect.Value{}, fmt.Errorf(errTmplFieldExpectISO8601)
	}
	parsed := utils.ParseTimeISO8601(str)
	if parsed.IsZero() {
		parsed = utils.ParseTime(utils.ISO8601Milliseconds, str)
	}
	if parsed.IsZero() {
		return reflect.Value{}, fmt.Errorf(errTmplFieldExpectISO8601)
	}
	return reflect.ValueOf(parsed), nil
}
// expireInToTimeConverter returns a converter that turns a JSON number of seconds (decoded as float64)
// into an absolute time, relative to issueTime. If issueTime is zero, it is relative to the time of
// conversion. Fractional seconds are truncated.
func expireInToTimeConverter(issueTime time.Time) valueConverterFunc {
	return func(v interface{}) (reflect.Value, error) {
		secs, ok := v.(float64)
		if !ok {
			return reflect.Value{}, fmt.Errorf(errTmplFieldExpectInt)
		}
		// work on a local copy: assigning to the captured issueTime would make every later call of
		// this converter reuse the first call's time.Now()
		base := issueTime
		if base.IsZero() {
			base = time.Now()
		}
		return reflect.ValueOf(base.Add(time.Duration(secs) * time.Second)), nil
	}
}
func stringToRefreshToken(v interface{}) (reflect.Value, error) {
s, ok := v.(string)
if !ok {
return reflect.Value{}, fmt.Errorf(errTmplFieldExpectString)
}
return reflect.ValueOf(NewDefaultRefreshToken(s)), nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package oauth2
import "context"
// Known token hints. The zero value is deliberately left unnamed and renders as "unknown".
const (
	_ TokenHint = iota
	TokenHintAccessToken
	TokenHintRefreshToken
)

// TokenHint indicates whether a token value refers to an access token or a refresh token,
// e.g. for TokenStoreReader.ReadAuthentication.
type TokenHint int

// tokenHintNames holds the string form of each named TokenHint.
var tokenHintNames = map[TokenHint]string{
	TokenHintAccessToken:  "access_token",
	TokenHintRefreshToken: "refresh_token",
}

// String implements fmt.Stringer. Hints without a name render as "unknown".
func (h TokenHint) String() string {
	if name, ok := tokenHintNames[h]; ok {
		return name
	}
	return "unknown"
}
// TokenStoreReader loads tokens, and the Authentication associated with them, by token value.
type TokenStoreReader interface {
// ReadAuthentication loads the Authentication associated with a token value.
// The token can be an AccessToken or a RefreshToken; hint presumably tells the implementation which kind tokenValue is.
ReadAuthentication(ctx context.Context, tokenValue string, hint TokenHint) (Authentication, error)
// ReadAccessToken loads the AccessToken with the given value.
// If the AccessToken is not associated with a valid security.ContextDetails (e.g. it was revoked), it returns an error.
ReadAccessToken(ctx context.Context, value string) (AccessToken, error)
// ReadRefreshToken loads the RefreshToken with the given value.
// This method does not imply any revocation status; that depends on the implementation.
ReadRefreshToken(ctx context.Context, value string) (RefreshToken, error)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"bytes"
"encoding/gob"
"encoding/json"
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
errorutils "github.com/cisco-open/go-lanai/pkg/utils/error"
"net/http"
)
// All "SubType" values are used as mask
// sub types of security.ErrorTypeCodeOAuth2
// Each sub-type code is security.ErrorTypeCodeOAuth2 + (index << errorutils.ErrorSubTypeOffset).
// The index starts at 1 because the first spec is the blank "_ = iota". Do not reorder: values depend on position.
const (
_ = iota
ErrorSubTypeCodeOAuth2Internal = security.ErrorTypeCodeOAuth2 + iota<<errorutils.ErrorSubTypeOffset
ErrorSubTypeCodeOAuth2ClientAuth
ErrorSubTypeCodeOAuth2Authorize
ErrorSubTypeCodeOAuth2Grant
ErrorSubTypeCodeOAuth2Res
)
// ErrorSubTypeCodeOAuth2Internal
// Error codes in each group below are consecutive offsets (1, 2, ...) from their sub-type code;
// the blank "_" spec takes offset 0. Append new codes at the end of a group to keep existing values.
const (
_ = ErrorSubTypeCodeOAuth2Internal + iota
ErrorCodeOAuth2InternalGeneral
)
// ErrorSubTypeCodeOAuth2ClientAuth
const (
_ = ErrorSubTypeCodeOAuth2ClientAuth + iota
ErrorCodeClientNotFound
ErrorCodeInvalidClient
)
// ErrorSubTypeCodeOAuth2Authorize
const (
_ = ErrorSubTypeCodeOAuth2Authorize + iota
ErrorCodeInvalidAuthorizeRequest
ErrorCodeInvalidResponseType
ErrorCodeInvalidRedirectUri
ErrorCodeAccessRejected
ErrorCodeOpenIDExt
)
// ErrorSubTypeCodeOAuth2Grant
const (
_ = ErrorSubTypeCodeOAuth2Grant + iota
ErrorCodeGranterNotAvailable
ErrorCodeUnauthorizedClient // grant type is not allowed for client
ErrorCodeInvalidTokenRequest
ErrorCodeInvalidGrant
ErrorCodeInvalidScope
ErrorCodeUnsupportedTokenType
ErrorCodeGeneric
)
// ErrorSubTypeCodeOAuth2Res
const (
_ = ErrorSubTypeCodeOAuth2Res + iota
ErrorCodeInvalidAccessToken
ErrorCodeInsufficientScope
ErrorCodeResourceServerGeneral // this should only be used for error deserialization
)
// ErrorTypes, can be used in errors.Is
// ErrorTypeOAuth2 matches any OAuth2 error; each ErrorSubTypeOAuth2* matches one of the
// ErrorSubTypeCodeOAuth2* families declared above.
//goland:noinspection GoUnusedGlobalVariable
var (
ErrorTypeOAuth2 = security.NewErrorType(security.ErrorTypeCodeOAuth2, errors.New("error type: oauth2"))
ErrorSubTypeOAuth2Internal = security.NewErrorSubType(ErrorSubTypeCodeOAuth2Internal, errors.New("error sub-type: internal"))
ErrorSubTypeOAuth2ClientAuth = security.NewErrorSubType(ErrorSubTypeCodeOAuth2ClientAuth, errors.New("error sub-type: oauth2 client auth"))
ErrorSubTypeOAuth2Authorize = security.NewErrorSubType(ErrorSubTypeCodeOAuth2Authorize, errors.New("error sub-type: oauth2 auth"))
ErrorSubTypeOAuth2Grant = security.NewErrorSubType(ErrorSubTypeCodeOAuth2Grant, errors.New("error sub-type: oauth2 grant"))
ErrorSubTypeOAuth2Res = security.NewErrorSubType(ErrorSubTypeCodeOAuth2Res, errors.New("error sub-type: oauth2 resource"))
)
/************************
Error EC
*************************/
// OAuth2 "error" codes written to error responses (see OAuth2Error.MarshalJSON).
//goland:noinspection GoCommentStart
const (
// https://tools.ietf.org/html/rfc6749#section-4.1.2.1
ErrorTranslationInvalidRequest = "invalid_request"
ErrorTranslationUnauthorizedClient = "unauthorized_client"
ErrorTranslationAccessDenied = "access_denied"
ErrorTranslationInvalidResponseType = "unsupported_response_type"
ErrorTranslationInvalidScope = "invalid_scope"
ErrorTranslationInternal = "server_error"
ErrorTranslationInternalNA = "temporarily_unavailable"
// https://tools.ietf.org/html/rfc6749#section-5.2
ErrorTranslationInvalidClient = "invalid_client"
ErrorTranslationInvalidGrant = "invalid_grant"
ErrorTranslationGrantNotSupported = "unsupported_grant_type"
// https://tools.ietf.org/html/rfc6750#section-3.1 (bearer token usage)
ErrorTranslationInsufficientScope = "insufficient_scope"
ErrorTranslationInvalidToken = "invalid_token"
// commonly used, not defined by an RFC
ErrorTranslationRedirectMismatch = "redirect_uri_mismatch"
// https://tools.ietf.org/html/rfc7009#section-4.1.1
ErrorTranslationUnsupportedTokenType = "unsupported_token_type"
// https://openid.net/specs/openid-connect-core-1_0.html#AuthError
ErrorTranslationInteractionRequired = "interaction_required"
ErrorTranslationLoginRequired = "login_required"
ErrorTranslationAcctSelectRequired = "account_selection_required"
ErrorTranslationConsentRequired = "consent_required"
ErrorTranslationInvalidRequestURI = "invalid_request_uri"
ErrorTranslationInvalidRequestObj = "invalid_request_object"
ErrorTranslationRequestUnsupported = "request_not_supported"
ErrorTranslationRequestURIUnsupported = "request_uri_not_supported"
ErrorTranslationRegistrationUnsupported = "registration_not_supported"
//ErrorTranslation = ""
)
/************************
Extensions
*************************/
// OAuth2ErrorTranslator is an error that knows how to present itself as an OAuth2 error response.
//goland:noinspection GoNameStartsWithPackageName
type OAuth2ErrorTranslator interface {
error
// TranslateErrorCode returns the OAuth2 "error" code, e.g. ErrorTranslationInvalidGrant.
TranslateErrorCode() string
// TranslateStatusCode returns the HTTP status code to respond with.
TranslateStatusCode() int
}
// OAuth2Error extends security.CodedError, and implements:
// - OAuth2ErrorTranslator
// - json.Marshaler
// - json.Unmarshaler
// - web.Headerer
// - web.StatusCoder
// - encoding.BinaryMarshaler
// - encoding.BinaryUnmarshaler
// Create instances with NewOAuth2Error or one of the New*Error constructors.
//goland:noinspection GoNameStartsWithPackageName
type OAuth2Error struct {
security.CodedError
// EC is the OAuth2 "error" code, e.g. one of the ErrorTranslation* constants
EC string // oauth error code
// SC is the HTTP status code reported by StatusCode and TranslateStatusCode
SC int // status code
}
// StatusCode implements web.StatusCoder by returning SC.
func (e *OAuth2Error) StatusCode() int {
	return e.SC
}
// Headers implements web.Headerer. Error responses are marked as non-cacheable
// ("Cache-Control: no-store", "Pragma: no-cache").
func (e *OAuth2Error) Headers() http.Header {
	// keys are already in canonical form, matching what http.Header.Add would produce
	return http.Header{
		"Cache-Control": {"no-store"},
		"Pragma":        {"no-cache"},
	}
}
// TranslateErrorCode implements OAuth2ErrorTranslator by returning EC.
func (e *OAuth2Error) TranslateErrorCode() string {
	return e.EC
}
// TranslateStatusCode implements OAuth2ErrorTranslator by returning SC.
func (e *OAuth2Error) TranslateStatusCode() int {
	return e.SC
}
// MarshalJSON implements json.Marshaler. The output is a JSON object containing the OAuth2 error code
// (ParameterError) and the error message (ParameterErrorDescription); the internal code and the status
// code are not included.
func (e *OAuth2Error) MarshalJSON() ([]byte, error) {
	return json.Marshal(map[string]string{
		ParameterError:            e.EC,
		ParameterErrorDescription: e.Error(),
	})
}
// UnmarshalJSON implements json.Unmarshaler
// Note: JSON doesn't include internal code error. So reconstruct error from JSON is not possible.
// Unmarshaler can only be used for opaque token checking HTTP call.
//
// The result always carries ErrorCodeResourceServerGeneral. EC comes from the ParameterError field and
// the message from ParameterErrorDescription. Fields whose values are not strings are ignored instead
// of failing the whole decode, so responses carrying extra non-string fields can still be read.
// SC is left unchanged.
func (e *OAuth2Error) UnmarshalJSON(data []byte) error {
	values := map[string]interface{}{}
	if err := json.Unmarshal(data, &values); err != nil {
		return err
	}
	ec, _ := values[ParameterError].(string)
	desc, _ := values[ParameterErrorDescription].(string)
	e.EC = ec
	e.CodedError = *security.NewCodedError(ErrorCodeResourceServerGeneral, desc)
	return nil
}
// oauth2ErrorCarrier mirrors OAuth2Error for gob encoding (see MarshalBinary/UnmarshalBinary).
// gob only encodes exported fields, so the embedded CodedError is exposed here as a named, exported field.
type oauth2ErrorCarrier struct {
CodedError security.CodedError
EC string // oauth error code
SC int // status code
}
// MarshalBinary implements encoding.BinaryMarshaler interface.
// The error is gob-encoded through oauth2ErrorCarrier. Value receiver: both OAuth2Error and *OAuth2Error
// satisfy encoding.BinaryMarshaler.
func (e OAuth2Error) MarshalBinary() ([]byte, error) {
	carrier := oauth2ErrorCarrier{CodedError: e.CodedError, EC: e.EC, SC: e.SC}
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(&carrier); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// UnmarshalBinary implements encoding.BinaryUnmarshaler interface.
// It decodes data produced by MarshalBinary and overwrites all fields of e; e is left untouched on error.
func (e *OAuth2Error) UnmarshalBinary(data []byte) error {
	var carrier oauth2ErrorCarrier
	if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&carrier); err != nil {
		return err
	}
	e.CodedError, e.EC, e.SC = carrier.CodedError, carrier.EC, carrier.SC
	return nil
}
/************************
Constructors
*************************/
// NewOAuth2Error creates an OAuth2Error with the internal error code, the message/value e and optional causes
// (passed to security.NewCodedError), the OAuth2 error code oauth2Code and the HTTP status code sc.
func NewOAuth2Error(code int64, e interface{}, oauth2Code string, sc int, causes ...interface{}) *OAuth2Error {
	err := &OAuth2Error{EC: oauth2Code, SC: sc}
	err.CodedError = *security.NewCodedError(code, e, causes...)
	return err
}
/* OAuth2Internal family */
// NewInternalError creates an ErrorCodeOAuth2InternalGeneral error translated to "server_error" with HTTP 400.
// NOTE(review): RFC 6749 describes server_error as a 500-class condition; 400 is kept for compatibility.
func NewInternalError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeOAuth2InternalGeneral, value, ErrorTranslationInternal, http.StatusBadRequest, causes...)
}
// NewInternalUnavailableError creates an ErrorCodeOAuth2InternalGeneral error translated to
// "temporarily_unavailable" with HTTP 400.
func NewInternalUnavailableError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeOAuth2InternalGeneral, value, ErrorTranslationInternalNA, http.StatusBadRequest, causes...)
}
/* OAuth2Auth family */
// NewGranterNotAvailableError creates an ErrorCodeGranterNotAvailable error translated to
// "unsupported_grant_type" with HTTP 400.
func NewGranterNotAvailableError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeGranterNotAvailable, value, ErrorTranslationGrantNotSupported, http.StatusBadRequest, causes...)
}
// NewInvalidTokenRequestError creates an ErrorCodeInvalidTokenRequest error translated to
// "invalid_request" with HTTP 400.
func NewInvalidTokenRequestError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeInvalidTokenRequest, value, ErrorTranslationInvalidRequest, http.StatusBadRequest, causes...)
}
// NewInvalidClientError creates an ErrorCodeInvalidClient error translated to "invalid_client" with HTTP 401.
func NewInvalidClientError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeInvalidClient, value, ErrorTranslationInvalidClient, http.StatusUnauthorized, causes...)
}
// NewClientNotFoundError creates an ErrorCodeClientNotFound error translated to "invalid_client" with HTTP 401.
func NewClientNotFoundError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeClientNotFound, value, ErrorTranslationInvalidClient, http.StatusUnauthorized, causes...)
}
// NewUnauthorizedClientError creates an ErrorCodeUnauthorizedClient error translated to
// "unauthorized_client" with HTTP 400.
func NewUnauthorizedClientError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeUnauthorizedClient, value, ErrorTranslationUnauthorizedClient, http.StatusBadRequest, causes...)
}
// NewInvalidGrantError creates an ErrorCodeInvalidGrant error translated to "invalid_grant" with HTTP 400.
func NewInvalidGrantError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeInvalidGrant, value, ErrorTranslationInvalidGrant, http.StatusBadRequest, causes...)
}
// NewInvalidScopeError creates an ErrorCodeInvalidScope error translated to "invalid_scope" with HTTP 400.
func NewInvalidScopeError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeInvalidScope, value, ErrorTranslationInvalidScope, http.StatusBadRequest, causes...)
}
// NewUnsupportedTokenTypeError creates an ErrorCodeUnsupportedTokenType error translated to
// "unsupported_token_type" with HTTP 400.
func NewUnsupportedTokenTypeError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeUnsupportedTokenType, value, ErrorTranslationUnsupportedTokenType, http.StatusBadRequest, causes...)
}
// NewGenericError creates an ErrorCodeGeneric error translated to "invalid_request" with HTTP 400.
func NewGenericError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeGeneric, value, ErrorTranslationInvalidRequest, http.StatusBadRequest, causes...)
}
// NewInvalidAuthorizeRequestError creates an ErrorCodeInvalidAuthorizeRequest error translated to
// "invalid_request" with HTTP 400.
func NewInvalidAuthorizeRequestError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeInvalidAuthorizeRequest, value, ErrorTranslationInvalidRequest, http.StatusBadRequest, causes...)
}
// NewInvalidRedirectUriError creates an ErrorCodeInvalidRedirectUri error translated to
// "redirect_uri_mismatch" with HTTP 400.
func NewInvalidRedirectUriError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeInvalidRedirectUri, value, ErrorTranslationRedirectMismatch, http.StatusBadRequest, causes...)
}
// NewInvalidResponseTypeError creates an ErrorCodeInvalidResponseType error translated to
// "unsupported_response_type" with HTTP 400.
func NewInvalidResponseTypeError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeInvalidResponseType, value, ErrorTranslationInvalidResponseType, http.StatusBadRequest, causes...)
}
// NewAccessRejectedError creates an ErrorCodeAccessRejected error translated to "access_denied" with HTTP 400.
func NewAccessRejectedError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeAccessRejected, value, ErrorTranslationAccessDenied, http.StatusBadRequest, causes...)
}
/* OAuth2Res family */
// NewInvalidAccessTokenError creates an ErrorCodeInvalidAccessToken error translated to "invalid_token" with HTTP 401.
func NewInvalidAccessTokenError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeInvalidAccessToken, value, ErrorTranslationInvalidToken, http.StatusUnauthorized, causes...)
}
// NewInsufficientScopeError creates an ErrorCodeInsufficientScope error translated to "insufficient_scope"
// with HTTP 403.
func NewInsufficientScopeError(value interface{}, causes ...interface{}) error {
	return NewOAuth2Error(ErrorCodeInsufficientScope, value, ErrorTranslationInsufficientScope, http.StatusForbidden, causes...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jwt
import (
"encoding/json"
"errors"
)
/*********************
Implements
*********************/
// jwtGoCompatibleClaims adapts an arbitrary claims value to golang-jwt's jwt.Claims by delegating
// JSON serialization/deserialization to the wrapped value.
type jwtGoCompatibleClaims struct {
	claims interface{}
}

// Valid returns an error only when no claims value is wrapped; it performs no other validation.
func (c *jwtGoCompatibleClaims) Valid() error {
	if c.claims != nil {
		return nil
	}
	return errors.New("embedded claims are nil")
}

// MarshalJSON serializes the wrapped claims value as-is.
func (c *jwtGoCompatibleClaims) MarshalJSON() ([]byte, error) {
	wrapped := c.claims
	return json.Marshal(wrapped)
}

// UnmarshalJSON decodes into the wrapped claims value, which therefore must be a non-nil pointer
// (json.Unmarshal rejects anything else).
func (c *jwtGoCompatibleClaims) UnmarshalJSON(data []byte) error {
	target := c.claims
	return json.Unmarshal(data, target)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jwt
import (
"encoding/base64"
"encoding/json"
"github.com/golang-jwt/jwt/v4"
"strings"
)
// ParseJwtHeaders extract JWT's headers without verifying the token.
// jwtValue must have exactly three dot-separated segments. The first segment is base64url-decoded
// (unpadded input is accepted) and unmarshalled as a JSON object. Every failure is reported as a
// *jwt.ValidationError flagged jwt.ValidationErrorMalformed.
func ParseJwtHeaders(jwtValue string) (map[string]interface{}, error) {
	segments := strings.Split(jwtValue, ".")
	if len(segments) != 3 {
		return nil, jwt.NewValidationError("token contains an invalid number of segments", jwt.ValidationErrorMalformed)
	}

	// decode header: use raw base64url when no padding is needed, otherwise pad and decode with padding
	header := segments[0]
	var headerBytes []byte
	var e error
	if rem := len(header) % 4; rem == 0 {
		headerBytes, e = base64.RawURLEncoding.DecodeString(header)
	} else {
		headerBytes, e = base64.URLEncoding.DecodeString(header + strings.Repeat("=", 4-rem))
	}
	if e != nil {
		return nil, &jwt.ValidationError{Inner: e, Errors: jwt.ValidationErrorMalformed}
	}

	// unmarshal header
	var headers map[string]interface{}
	if e = json.Unmarshal(headerBytes, &headers); e != nil {
		return nil, &jwt.ValidationError{Inner: e, Errors: jwt.ValidationErrorMalformed}
	}
	return headers, nil
}
//func printPrivateKey(key *rsa.PrivateKey) string {
// //bytes := x509.MarshalPKCS1PrivateKey(key)
// bytes, _ := x509.MarshalPKCS8PrivateKey(key)
// return base64.StdEncoding.EncodeToString(bytes)
//}
//
//func printPublicKey(key *rsa.PublicKey) string {
// bytes := x509.MarshalPKCS1PublicKey(key)
// return base64.StdEncoding.EncodeToString(bytes)
//}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jwt
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/golang-jwt/jwt/v4"
)
/*********************
Abstract
*********************/
// JwtDecoder parses a JWT string and extracts its claims.
//goland:noinspection GoNameStartsWithPackageName
type JwtDecoder interface {
// Decode parses token and returns its claims as oauth2.Claims.
Decode(ctx context.Context, token string) (oauth2.Claims, error)
// DecodeWithClaims parses token and unmarshals its claims into the caller-provided claims value.
DecodeWithClaims(ctx context.Context, token string, claims interface{}) error
}
/*********************
Constructors
*********************/
var (
	// SymmetricSigningMethods are the HMAC-based signing methods.
	SymmetricSigningMethods = []jwt.SigningMethod{
		jwt.SigningMethodHS256, jwt.SigningMethodHS384, jwt.SigningMethodHS512,
	}
	// AsymmetricSigningMethods are the public-key signing methods (RSA, ECDSA, RSA-PSS, EdDSA).
	// NewSignedJwtDecoder accepts these by default.
	AsymmetricSigningMethods = []jwt.SigningMethod{
		jwt.SigningMethodRS256, jwt.SigningMethodRS384, jwt.SigningMethodRS512,
		jwt.SigningMethodES256, jwt.SigningMethodES384, jwt.SigningMethodES512,
		jwt.SigningMethodPS256, jwt.SigningMethodPS384, jwt.SigningMethodPS512,
		jwt.SigningMethodEdDSA,
	}
	// SupportedSigningMethods lists the asymmetric methods followed by the symmetric ones.
	// It is built in its own backing array so it can never alias AsymmetricSigningMethods.
	SupportedSigningMethods = append(append([]jwt.SigningMethod{}, AsymmetricSigningMethods...), SymmetricSigningMethods...)
)
// VerifyOptions customizes a VerifyOption; pass them to NewSignedJwtDecoder.
type VerifyOptions func(opt *VerifyOption)
// VerifyOption configures a SignedJwtDecoder.
type VerifyOption struct {
// JwkStore provides the verification keys.
JwkStore JwkStore
// JwkName is the key name used when the JWT has no "kid" header.
JwkName string
// Methods lists the accepted signing methods ("alg" header); NewSignedJwtDecoder defaults to AsymmetricSigningMethods.
Methods []jwt.SigningMethod
// ParserOptions are passed to jwt.NewParser; NewSignedJwtDecoder defaults to jwt.WithoutClaimsValidation().
ParserOptions []jwt.ParserOption
}
// VerifyWithJwkStore is a VerifyOptions that sets the JwkStore and the default key name used for verification.
// The key name is only used as a fallback when the JWT being verified has no "kid" header.
func VerifyWithJwkStore(store JwkStore, jwkName string) VerifyOptions {
	return func(opt *VerifyOption) {
		opt.JwkStore, opt.JwkName = store, jwkName
	}
}
// VerifyWithMethods is a VerifyOptions that replaces the allowed signing methods ("alg" header).
// By default, NewSignedJwtDecoder accepts AsymmetricSigningMethods only, and never plaintext ("none") JWTs.
func VerifyWithMethods(methods ...jwt.SigningMethod) VerifyOptions {
	allowed := methods
	return func(opt *VerifyOption) {
		opt.Methods = allowed
	}
}
// NewSignedJwtDecoder creates a SignedJwtDecoder. By default it accepts AsymmetricSigningMethods
// and skips claims validation (jwt.WithoutClaimsValidation); opts can override both. The allowed
// methods are always enforced through jwt.WithValidMethods, which is appended after any custom
// ParserOptions.
func NewSignedJwtDecoder(opts ...VerifyOptions) *SignedJwtDecoder {
	opt := VerifyOption{
		Methods:       AsymmetricSigningMethods,
		ParserOptions: []jwt.ParserOption{jwt.WithoutClaimsValidation()},
	}
	for _, fn := range opts {
		fn(&opt)
	}
	methods := make([]string, len(opt.Methods))
	for i, m := range opt.Methods {
		methods[i] = m.Alg()
	}
	// Build a fresh slice. Appending directly to opt.ParserOptions could write into the backing array
	// of a slice supplied by a caller's VerifyOptions when it has spare capacity.
	parserOpts := make([]jwt.ParserOption, 0, len(opt.ParserOptions)+1)
	parserOpts = append(parserOpts, opt.ParserOptions...)
	parserOpts = append(parserOpts, jwt.WithValidMethods(methods))
	return &SignedJwtDecoder{
		jwkName:  opt.JwkName,
		jwkStore: opt.JwkStore,
		parser:   jwt.NewParser(parserOpts...),
	}
}
/*********************
Implements
*********************/
// SignedJwtDecoder implements JwtDecoder for signed JWTs.
// Signatures are verified with keys from jwkStore, looked up by the "kid" header or, when the header
// is absent, by jwkName. Create instances with NewSignedJwtDecoder.
type SignedJwtDecoder struct {
jwkName string
jwkStore JwkStore
parser *jwt.Parser
}
// Decode verifies tokenString and returns its claims as oauth2.MapClaims.
func (dec *SignedJwtDecoder) Decode(ctx context.Context, tokenString string) (oauth2.Claims, error) {
	var claims = oauth2.MapClaims{}
	err := dec.DecodeWithClaims(ctx, tokenString, &claims)
	if err != nil {
		return nil, err
	}
	return claims, nil
}
// DecodeWithClaims verifies tokenString and unmarshals its claims into the given target.
// A target that is not a jwt.Claims is wrapped in jwtGoCompatibleClaims before parsing.
func (dec *SignedJwtDecoder) DecodeWithClaims(ctx context.Context, tokenString string, claims interface{}) (err error) {
	var target jwt.Claims
	if jwtClaims, ok := claims.(jwt.Claims); ok {
		target = jwtClaims
	} else {
		target = &jwtGoCompatibleClaims{claims: claims}
	}
	_, err = dec.parser.ParseWithClaims(tokenString, target, dec.keyFunc(ctx))
	return err
}
// keyFunc returns a jwt.Keyfunc that resolves the verification key for a token. The key is looked up by the
// token's "kid" header when present, and otherwise by the decoder's default jwkName.
// If no JwkStore is configured (the NewSignedJwtDecoder default), or the store returns no JWK, the keyfunc
// returns an error instead of panicking on a nil value.
func (dec *SignedJwtDecoder) keyFunc(ctx context.Context) jwt.Keyfunc {
	return func(unverified *jwt.Token) (interface{}, error) {
		if dec.jwkStore == nil {
			return nil, fmt.Errorf("unable to verify JWT: JwkStore is not configured")
		}
		var jwk Jwk
		var e error
		switch kid, ok := unverified.Header[JwtHeaderKid].(string); {
		case ok:
			jwk, e = dec.jwkStore.LoadByKid(ctx, kid)
		default:
			jwk, e = dec.jwkStore.LoadByName(ctx, dec.jwkName)
		}
		switch {
		case e != nil:
			return nil, e
		case jwk == nil:
			return nil, fmt.Errorf("unable to verify JWT: no JWK found")
		}
		return jwk.Public(), nil
	}
}
// PlaintextJwtDecoder is a JWT decoder (presumably implementing the package's JwtDecoder interface, which is
// declared outside this view) for unsigned JWTs: it only accepts the "none" algorithm.
// NOTE(review): jwkName and jwkStore are not used by any PlaintextJwtDecoder method visible here.
type PlaintextJwtDecoder struct {
	jwkName string
	jwkStore JwkStore
	parser *jwt.Parser
}
// NewPlaintextJwtDecoder creates a PlaintextJwtDecoder whose parser accepts only the "none" signing method
// and has claims validation disabled.
func NewPlaintextJwtDecoder() *PlaintextJwtDecoder {
	allowed := []string{jwt.SigningMethodNone.Alg()}
	return &PlaintextJwtDecoder{
		parser: jwt.NewParser(
			jwt.WithoutClaimsValidation(),
			jwt.WithValidMethods(allowed),
		),
	}
}
// Decode parses an unsigned tokenString and returns its claims as oauth2.MapClaims.
func (dec *PlaintextJwtDecoder) Decode(ctx context.Context, tokenString string) (oauth2.Claims, error) {
	result := oauth2.MapClaims{}
	if err := dec.DecodeWithClaims(ctx, tokenString, &result); err != nil {
		return nil, err
	}
	return result, nil
}
// DecodeWithClaims parses an unsigned tokenString into claims. A target that is not a jwt.Claims is wrapped
// in jwtGoCompatibleClaims before parsing.
func (dec *PlaintextJwtDecoder) DecodeWithClaims(_ context.Context, tokenString string, claims interface{}) (err error) {
	target, isJwtClaims := claims.(jwt.Claims)
	if !isJwtClaims {
		target = &jwtGoCompatibleClaims{claims: claims}
	}
	_, err = dec.parser.ParseWithClaims(tokenString, target, dec.keyFunc)
	return err
}
// keyFunc accepts only tokens whose "alg" header is exactly "none". For those it returns the jwt library's
// constant that permits unsigned tokens; any other algorithm yields an error.
func (dec *PlaintextJwtDecoder) keyFunc(unverified *jwt.Token) (interface{}, error) {
	alg, isString := unverified.Header[JwtHeaderAlgorithm].(string)
	if !isString || alg != "none" {
		return nil, fmt.Errorf("unsupported alg")
	}
	return jwt.UnsafeAllowNoneSignatureType, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jwt
import (
"context"
"fmt"
"github.com/golang-jwt/jwt/v4"
)
/*********************
Abstract
*********************/
// JwtEncoder encodes claims into a JWT string.
//
//goland:noinspection GoNameStartsWithPackageName
type JwtEncoder interface {
	// Encode serializes claims into a JWT. claims may be a jwt.Claims or another value
	// (SignedJwtEncoder wraps non-jwt.Claims values before signing).
	Encode(ctx context.Context, claims interface{}) (string, error)
}
/*********************
Constructors
*********************/
// SigningOptions is a functional option applied to a SigningOption by NewSignedJwtEncoder.
type SigningOptions func(opt *SigningOption)
// SigningOption holds the settings NewSignedJwtEncoder uses to build a SignedJwtEncoder.
type SigningOption struct {
	// JwkStore provides the signing key, loaded by JwkName.
	JwkStore JwkStore
	// JwkName is the name of the key used for signing.
	JwkName string
	// Method is the signing method; nil means it is resolved from the private key type at encode time.
	Method jwt.SigningMethod
}
// SignWithJwkStore returns a SigningOptions that selects the JwkStore and the name of the key used for signing.
func SignWithJwkStore(store JwkStore, jwkName string) SigningOptions {
	return SigningOptions(func(opt *SigningOption) {
		opt.JwkName, opt.JwkStore = jwkName, store
	})
}
// SignWithMethod returns a SigningOptions that sets the signing method. A nil method makes the encoder
// resolve the signing method from the type of the private key.
func SignWithMethod(method jwt.SigningMethod) SigningOptions {
	apply := func(opt *SigningOption) { opt.Method = method }
	return apply
}
// NewSignedJwtEncoder creates a JwtEncoder that signs JWTs with the configured method (RS256 by default).
// The provided JwkStore must supply private keys appropriate for that method.
// Note: with HS algorithms, the HMAC secret acts as both the public and the private key, so it would be
// exposed via the JWKS endpoint. Service implementers are responsible for protecting that endpoint to
// prevent accidental leaking of the HMAC secret.
func NewSignedJwtEncoder(opts ...SigningOptions) *SignedJwtEncoder {
	cfg := &SigningOption{Method: jwt.SigningMethodRS256}
	for i := range opts {
		opts[i](cfg)
	}
	return &SignedJwtEncoder{
		jwkName:  cfg.JwkName,
		jwkStore: cfg.JwkStore,
		method:   cfg.Method,
	}
}
/*********************
Implements
*********************/
// SignedJwtEncoder implements JwtEncoder. It encodes claims with the crypto signature of choice.
// Encoding may fail if the private key is not compatible with the signing method.
type SignedJwtEncoder struct {
	// jwkName names the signing key; a "kid" header is only added when the loaded key's id differs from it
	jwkName string
	jwkStore JwkStore
	// method is the signing method; when nil, it is resolved from the private key at encode time
	method jwt.SigningMethod
}
// Encode signs claims with the private key named enc.jwkName and returns the compact JWT string.
// If no signing method was configured, it is resolved from the private key type. Claims that are not
// jwt.Claims are wrapped in jwtGoCompatibleClaims before signing.
func (enc *SignedJwtEncoder) Encode(ctx context.Context, claims interface{}) (string, error) {
	jwk, err := enc.findJwk(ctx)
	if err != nil {
		return "", err
	}

	method := enc.method
	if method == nil {
		resolved, resolveErr := resolveSigningMethod(jwk.Private())
		if resolveErr != nil {
			return "", resolveErr
		}
		method = resolved
	}

	var payload jwt.Claims = &jwtGoCompatibleClaims{claims: claims}
	if jwtClaims, ok := claims.(jwt.Claims); ok {
		payload = jwtClaims
	}
	token := jwt.NewWithClaims(method, payload)

	// jwk.Name() may alias several kids to support rotation.
	// A store without rotation is expected to return a JWK whose id equals its name. The decoder can then
	// obtain that key out of band without ambiguity, so no "kid" header is needed.
	// A store with rotation is expected to return JWKs whose ids differ from the name. In that case the
	// "kid" header is required so the decoder can pick the right key.
	if kid := jwk.Id(); kid != enc.jwkName {
		token.Header[JwtHeaderKid] = kid
	}
	return token.SignedString(jwk.Private())
}
// findJwk loads the JWK named enc.jwkName and ensures it carries a private key.
// If no JwkStore is configured (the NewSignedJwtEncoder default), it returns an error instead of panicking
// on the nil store.
func (enc *SignedJwtEncoder) findJwk(ctx context.Context) (PrivateJwk, error) {
	if enc.jwkStore == nil {
		return nil, fmt.Errorf("unable to sign JWT: JwkStore is not configured")
	}
	jwk, e := enc.jwkStore.LoadByName(ctx, enc.jwkName)
	if e != nil {
		return nil, e
	}
	private, ok := jwk.(PrivateJwk)
	if !ok {
		return nil, fmt.Errorf("JWK with name[%s] doesn't have private key", enc.jwkName)
	}
	return private, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jwt
import (
"context"
"crypto"
"encoding/json"
"reflect"
)
/*********************
Abstraction
*********************/
// Jwk is a JSON Web Key exposing at least its public key.
type Jwk interface {
	// Id returns the key id ("kid").
	Id() string
	// Name returns the key name. Several keys may share a name, e.g. to support rotation.
	Name() string
	// Public returns the public key. For HMAC keys created by NewPrivateJwk, this is the secret itself.
	Public() crypto.PublicKey
}
// PrivateJwk is a Jwk that also carries its private key.
type PrivateJwk interface {
	Jwk
	// Private returns the private key (for HMAC keys, the shared secret).
	Private() crypto.PrivateKey
}
// JwkStore loads JWKs by kid or by name.
type JwkStore interface {
	// LoadByKid returns the JWK associated with the given KID.
	// This method is usually used when decoding/verifying JWT tokens.
	LoadByKid(ctx context.Context, kid string) (Jwk, error)
	// LoadByName returns the JWK associated with the given name.
	// The method might return different JWKs for the same name if the store also supports rotation.
	// This method is usually used when encoding/signing JWT tokens.
	// Note: if the store does not support rotation (i.e. it does not implement JwkRotator),
	// it could use the name as the JWK id. Doing so allows the encoder to omit the "kid" header
	// from the JWT token. This supports the use case where the encoder and decoder agree on the key
	// through an out-of-band mechanism without using "kid".
	// See the comment in SignedJwtEncoder.Encode for more details.
	LoadByName(ctx context.Context, name string) (Jwk, error)
	// LoadAll returns all JWKs with the given names. If no name is provided, all JWKs are returned.
	LoadAll(ctx context.Context, names ...string) ([]Jwk, error)
}
// JwkRotator is a JwkStore that can rotate the JWK returned for a name.
type JwkRotator interface {
	JwkStore
	// Rotate changes the JWK of the given name to the next candidate.
	Rotate(ctx context.Context, name string) error
}
/*********************
Implements Base
*********************/
// GenericJwk implements Jwk. It holds only public key material.
type GenericJwk struct {
	kid string
	name string
	public crypto.PublicKey
}
// Id returns the key id ("kid").
func (k *GenericJwk) Id() string {
	return k.kid
}
// Name returns the key name.
func (k *GenericJwk) Name() string {
	return k.name
}
// Public returns the public key.
func (k *GenericJwk) Public() crypto.PublicKey {
	return k.public
}
// MarshalJSON encodes the JWK's id and public key as RFC 7517 JSON (via marshalJwk).
func (k *GenericJwk) MarshalJSON() ([]byte, error) {
	data, err := marshalJwk(k)
	return data, err
}

// UnmarshalJSON decodes RFC 7517 JSON into k. Keys parsed this way use their "kid" as name as well.
func (k *GenericJwk) UnmarshalJSON(data []byte) error {
	parsed, err := unmarshalJwk(data)
	if err != nil {
		return err
	}
	if generic, ok := parsed.(*GenericJwk); ok {
		*k = *generic
		return nil
	}
	*k = GenericJwk{kid: parsed.Id(), name: parsed.Name(), public: parsed.Public()}
	return nil
}
// GenericPrivateJwk implements Jwk and PrivateJwk by adding a private key to GenericJwk.
// Its promoted MarshalJSON only encodes the public part; note that for HMAC keys the public part is the secret.
type GenericPrivateJwk struct {
	GenericJwk
	private crypto.PrivateKey
}
// Private returns the private key (for HMAC keys, the secret bytes).
func (k *GenericPrivateJwk) Private() crypto.PrivateKey {
	return k.private
}
/*********************
Constructors
*********************/
// typeOfBytes is the reflect.Type of []byte, used to detect and convert HMAC secrets.
var typeOfBytes = reflect.TypeOf([]byte(nil))
// publicKey matches public key values that can compare themselves with another public key.
type publicKey interface {
	Equal(x crypto.PublicKey) bool
}
// privateKey matches private key values that can derive their public key and compare themselves.
type privateKey interface {
	Public() crypto.PublicKey
	Equal(x crypto.PrivateKey) bool
}
// NewJwk creates a Jwk holding the specified public key.
// Supported public key types:
//   - *rsa.PublicKey
//   - *ecdsa.PublicKey
//   - ed25519.PublicKey
//   - []byte (MAC secret)
//   - any key implementing:
//     interface{
//     Equal(x crypto.PublicKey) bool
//     }
func NewJwk(kid string, name string, pubKey crypto.PublicKey) Jwk {
	jwk := new(GenericJwk)
	jwk.kid, jwk.name, jwk.public = kid, name, pubKey
	return jwk
}
// ParseJwk parses a Jwk from JSON as specified in RFC 7517 and RFC 7518.
// Note: private key information is ignored in the parsed Jwk.
// Supported public key types:
//   - *rsa.PublicKey (kty = RSA)
//   - *ecdsa.PublicKey (kty = EC)
//   - ed25519.PublicKey (kty = OKP)
//   - []byte (symmetric key, e.g. MAC secret) (kty = oct)
//
// See: RFC7517 https://datatracker.ietf.org/doc/html/rfc7517
// See: RFC7518 https://datatracker.ietf.org/doc/html/rfc7518
func ParseJwk(jsonData []byte) (Jwk, error) {
	parsed := &GenericJwk{}
	if err := json.Unmarshal(jsonData, parsed); err != nil {
		return nil, err
	}
	return parsed, nil
}
// NewPrivateJwk creates a PrivateJwk holding the specified private key; the public key is derived from it.
// Supported private key types:
//   - *rsa.PrivateKey
//   - *ecdsa.PrivateKey
//   - ed25519.PrivateKey
//   - []byte (MAC secret). Any value convertible to []byte is converted, and the secret doubles as the public key.
//   - any key implementing:
//     interface{
//     Public() crypto.PublicKey
//     Equal(x crypto.PrivateKey) bool
//     }
//
// For any other value (including nil), the resulting JWK has no public key.
func NewPrivateJwk(kid string, name string, privKey crypto.PrivateKey) PrivateJwk {
	var pubKey crypto.PublicKey
	switch v := privKey.(type) {
	case nil:
		// Nothing to derive. reflect.ValueOf(nil) yields an invalid Value, and calling CanConvert on it panics.
	case privateKey:
		pubKey = v.Public()
	default:
		// HMAC secret
		if rv := reflect.ValueOf(privKey); rv.CanConvert(typeOfBytes) {
			privKey = rv.Convert(typeOfBytes).Interface()
			pubKey = privKey
		}
	}
	return &GenericPrivateJwk{
		GenericJwk: GenericJwk{
			kid:    kid,
			name:   name,
			public: pubKey,
		},
		private: privKey,
	}
}
package jwt
import (
"bytes"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"math/big"
)
// JWK "kty" (key type) values, see RFC 7518 section 6.1 and RFC 8037 section 2.
const (
	JwkTypeEC = `EC`
	JwkTypeRSA = `RSA`
	JwkTypeOctet = `oct`
	// JwkTypeEdDSA is the "Octet Key Pair" type; this file only produces Ed25519 keys of this type.
	JwkTypeEdDSA = `OKP`
)
// marshalJwk encodes the JWK's id and public key into the kty-specific JSON representation.
// Only the public key is encoded; unsupported key types yield an error.
func marshalJwk(jwk Jwk) ([]byte, error) {
	meta := generalJwk{Id: jwk.Id()}
	pub := jwk.Public()
	switch key := pub.(type) {
	case *rsa.PublicKey:
		return json.Marshal(makeRSAPublicJwk(key, meta))
	case *ecdsa.PublicKey:
		return json.Marshal(makeECPublicJwk(key, meta))
	case ed25519.PublicKey:
		return json.Marshal(makeOKPJwk(key, meta))
	case []byte:
		return json.Marshal(makeOctetJwk(key, meta))
	}
	return nil, fmt.Errorf(`unable to marshal JWK: unrecognized public key type: %T`, pub)
}
// unmarshalJwk reads "kty" from the JSON, decodes the matching representation and converts it into a Jwk.
func unmarshalJwk(data []byte) (Jwk, error) {
	var header generalJwk
	if err := json.Unmarshal(data, &header); err != nil {
		return nil, err
	}
	// pick the concrete JSON representation by "kty"
	factories := map[string]func() publicJwk{
		JwkTypeRSA:   func() publicJwk { return &rsaPublicJwk{} },
		JwkTypeEC:    func() publicJwk { return &ecPublicJwk{} },
		JwkTypeOctet: func() publicJwk { return &octetJwk{} },
		JwkTypeEdDSA: func() publicJwk { return &okpJwk{} },
	}
	newJwk, supported := factories[header.Type]
	if !supported {
		return nil, fmt.Errorf(`invalid 'kty': %s`, header.Type)
	}
	parsed := newJwk()
	if err := json.Unmarshal(data, parsed); err != nil {
		return nil, err
	}
	return parsed.toJwk()
}
// jwkBytes is a byte string serialized as unpadded base64url text, as used by JWK members.
type jwkBytes []byte

// String returns the unpadded base64url encoding of the bytes.
func (b jwkBytes) String() string {
	return base64.RawURLEncoding.EncodeToString([]byte(b))
}

// BigInt interprets the bytes as an unsigned big-endian integer. It returns nil when there are no bytes.
func (b jwkBytes) BigInt() *big.Int {
	if len(b) != 0 {
		return new(big.Int).SetBytes(b)
	}
	return nil
}

// MarshalText implements encoding.TextMarshaler using the base64url form.
func (b jwkBytes) MarshalText() ([]byte, error) {
	text := b.String()
	return []byte(text), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. On invalid input, b is left unchanged.
func (b *jwkBytes) UnmarshalText(data []byte) error {
	decoded, err := base64.RawURLEncoding.DecodeString(string(data))
	if err == nil {
		*b = decoded
	}
	return err
}
// generalJwk holds the JWK members common to all key types: "kid" and "kty".
type generalJwk struct {
	Id string `json:"kid"`
	Type string `json:"kty"`
}
// publicJwk is implemented by the kty-specific JSON representations; toJwk converts one into a Jwk.
type publicJwk interface {
	toJwk() (Jwk, error)
}
// ecPublicJwk is the JSON representation of an EC public key (kty "EC").
// RFC 7518: https://datatracker.ietf.org/doc/html/rfc7518#section-6.2
type ecPublicJwk struct {
	generalJwk
	// Curve is the "crv" curve name, e.g. "P-256"
	Curve string `json:"crv"`
	// CoordinateX and CoordinateY are the point's coordinates as big-endian bytes
	CoordinateX jwkBytes `json:"x"`
	CoordinateY jwkBytes `json:"y,omitempty"`
}
// toJwk converts this EC JWK into a Jwk holding an *ecdsa.PublicKey, using the kid as the name too.
// Only "P-256", "P-384" and "P-521" are supported. Both coordinates are required (RFC 7518 section 6.2.1)
// and must not be longer than the curve's field size. Otherwise an error is returned instead of a key
// with nil or out-of-range coordinates; JWK input may come from a remote endpoint.
func (j ecPublicJwk) toJwk() (Jwk, error) {
	var curve elliptic.Curve
	switch j.Curve {
	case "P-256":
		curve = elliptic.P256()
	case "P-384":
		curve = elliptic.P384()
	case "P-521":
		curve = elliptic.P521()
	default:
		return nil, fmt.Errorf(`unsupported 'crv' of EC JWK`)
	}
	size := (curve.Params().BitSize + 7) / 8
	switch {
	case len(j.CoordinateX) == 0 || len(j.CoordinateY) == 0:
		return nil, fmt.Errorf(`invalid EC JWK: 'x' and 'y' are required`)
	case len(j.CoordinateX) > size || len(j.CoordinateY) > size:
		return nil, fmt.Errorf(`invalid EC JWK: coordinate is longer than %d bytes`, size)
	}
	key := &ecdsa.PublicKey{
		Curve: curve,
		X:     j.CoordinateX.BigInt(),
		Y:     j.CoordinateY.BigInt(),
	}
	return NewJwk(j.Id, j.Id, key), nil
}
// makeECPublicJwk converts an ECDSA public key into its JWK JSON representation (kty "EC").
// RFC 7518 sections 6.2.1.2/6.2.1.3 require "x" and "y" to be the full length of the curve's field size.
// Coordinates are therefore left-padded with zeros; big.Int.Bytes alone drops leading zero bytes.
// A key without a Curve yields an empty "crv" and minimal-length coordinates instead of panicking.
func makeECPublicJwk(key *ecdsa.PublicKey, params generalJwk) ecPublicJwk {
	var crv string
	var size int
	if key.Curve != nil {
		if curveParams := key.Curve.Params(); curveParams != nil {
			crv = curveParams.Name
			size = (curveParams.BitSize + 7) / 8
		}
	}
	coordinate := func(v *big.Int) []byte {
		switch {
		case v == nil:
			return nil
		case size > 0 && v.Sign() >= 0 && v.BitLen() <= size*8:
			return v.FillBytes(make([]byte, size))
		default:
			// unknown field size or out-of-range value: keep the minimal encoding
			return v.Bytes()
		}
	}
	params.Type = JwkTypeEC
	return ecPublicJwk{
		generalJwk:  params,
		Curve:       crv,
		CoordinateX: coordinate(key.X),
		CoordinateY: coordinate(key.Y),
	}
}
// rsaPublicJwk is the JSON representation of an RSA public key (kty "RSA").
// RFC 7518: https://datatracker.ietf.org/doc/html/rfc7518#section-6.3
type rsaPublicJwk struct {
	generalJwk
	// Modulus ("n") and Exponent ("e") are unsigned big-endian integers
	Modulus jwkBytes `json:"n"`
	Exponent jwkBytes `json:"e"`
}
// toJwk converts this RSA JWK into a Jwk holding an *rsa.PublicKey, using the kid as the name too.
// "n" and "e" are required (RFC 7518 section 6.3.1). A missing exponent would otherwise dereference a nil
// *big.Int. The exponent must be positive and fit into 31 bits so that it converts to int safely.
// JWK input may come from a remote endpoint, so all of this is validated here.
func (j rsaPublicJwk) toJwk() (Jwk, error) {
	n, e := j.Modulus.BigInt(), j.Exponent.BigInt()
	switch {
	case n == nil || e == nil:
		return nil, fmt.Errorf(`invalid RSA JWK: 'n' and 'e' are required`)
	case n.Sign() == 0:
		return nil, fmt.Errorf(`invalid RSA JWK: 'n' must not be zero`)
	case e.Sign() == 0 || e.BitLen() > 31:
		return nil, fmt.Errorf(`invalid RSA JWK: unsupported exponent 'e'`)
	}
	key := &rsa.PublicKey{
		N: n,
		E: int(e.Int64()),
	}
	return NewJwk(j.Id, j.Id, key), nil
}
// makeRSAPublicJwk converts an RSA public key into its JWK JSON representation (kty "RSA").
func makeRSAPublicJwk(key *rsa.PublicKey, params generalJwk) rsaPublicJwk {
	params.Type = JwkTypeRSA
	var out rsaPublicJwk
	out.generalJwk = params
	out.Modulus = key.N.Bytes()
	// "e" is the exponent as big-endian bytes without leading zeros
	out.Exponent = bigEndian(key.E)
	return out
}
// octetJwk is the JSON representation of a symmetric key (kty "oct"), e.g. an HMAC secret.
// RFC 7518: https://datatracker.ietf.org/doc/html/rfc7518#section-6.4
type octetJwk struct {
	generalJwk
	// Key is the raw secret ("k")
	Key jwkBytes `json:"k"`
}
// toJwk converts this symmetric JWK into a Jwk whose public key is the raw secret, using the kid as the name too.
// "k" is required (RFC 7518 section 6.4.1); an empty secret is rejected rather than accepted as an HMAC key.
func (j octetJwk) toJwk() (Jwk, error) {
	if len(j.Key) == 0 {
		return nil, fmt.Errorf(`invalid oct JWK: 'k' is required`)
	}
	return NewJwk(j.Id, j.Id, []byte(j.Key)), nil
}
// makeOctetJwk converts a symmetric secret into its JWK JSON representation (kty "oct").
func makeOctetJwk(key []byte, params generalJwk) octetJwk {
	params.Type = JwkTypeOctet
	var out octetJwk
	out.generalJwk = params
	out.Key = key
	return out
}
// okpJwk is the JSON representation of an Ed25519 public key. "OKP" stands for Octet Key Pair.
// RFC 8037: https://datatracker.ietf.org/doc/html/rfc8037#section-2
type okpJwk struct {
	generalJwk
	// Curve is the "crv" subtype, e.g. "Ed25519"
	Curve string `json:"crv"`
	// PublicKey is the raw public key ("x")
	PublicKey jwkBytes `json:"x"`
}
// toJwk converts this OKP JWK into a Jwk holding an ed25519.PublicKey, using the kid as the name too.
// Only "crv": "Ed25519" is supported. "x" must be exactly ed25519.PublicKeySize bytes, because
// ed25519.Verify panics for keys of any other length.
func (j okpJwk) toJwk() (Jwk, error) {
	if j.Curve != "Ed25519" {
		return nil, fmt.Errorf(`unsupported 'crv' of OKP JWK`)
	}
	if len(j.PublicKey) != ed25519.PublicKeySize {
		return nil, fmt.Errorf(`invalid OKP JWK: 'x' must be %d bytes`, ed25519.PublicKeySize)
	}
	return NewJwk(j.Id, j.Id, ed25519.PublicKey(j.PublicKey)), nil
}
// makeOKPJwk converts an Ed25519 public key into its JWK JSON representation (kty "OKP", crv "Ed25519").
func makeOKPJwk(key ed25519.PublicKey, params generalJwk) okpJwk {
	params.Type = JwkTypeEdDSA
	var out okpJwk
	out.generalJwk = params
	out.Curve = "Ed25519"
	out.PublicKey = jwkBytes(key)
	return out
}
// bigEndian returns i, converted to uint64, as big-endian bytes with leading zero bytes removed.
// Zero is returned as all 8 (zero) bytes.
func bigEndian(i int) []byte {
	var raw [8]byte
	binary.BigEndian.PutUint64(raw[:], uint64(i))
	if trimmed := bytes.TrimLeft(raw[:], "\x00"); len(trimmed) != 0 {
		return trimmed
	}
	return raw[:]
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jwt
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/hmac"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/binary"
"encoding/hex"
"encoding/pem"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
"hash"
)
// Error messages and format templates used by FileJwkStore and its loading helpers.
const (
	errTmplInvalidJwkName = `invalid JWK name`
	errTmplPubPrivMixed = `found both public and private key block in same PEM file`
	errTmplNoKeyFoundInPem = `PEM file doesn't includes any supported private nor public keys`
	errTmplUnsupportedPubKey = `non-supported public key [%T] in certificate`
	errTmplUnsupportedFile = `unrecognized crypto key file format [%s]`
	errTmplUnsupportedBlock = `non-supported block [%T] in the file`
)
// FileJwkStore implements JwkStore and JwkRotator.
// This store loads public and private keys from key files.
// File locations and "kids" are read from properties, and the store rotates between these pre-defined keys.
// The properties are structured as follows:
//
//	keys:
//	  my-key-name:
//	    id: my-key-id
//	    format: pem
//	    file: my-key-file.pem
//
// All keys loaded under the same key name share that name. LoadByName returns one of them, selected by the
// current index for that name; Rotate increments that index, wrapping around.
// If the PEM file contains one key, the key id is the same as the key name.
//
// If the PEM file contains multiple keys, the key id is generated as follows:
//
//   - If the id property is provided, the key id is the property id plus an integer suffix (the block index).
//   - If the id property is not provided, the key id is derived from the key material (a hash of the public key,
//     or an HMAC of the name for HMAC keys), so the value is consistent across restarts.
//
// Supports PEM format, including:
//  1. PKCS8 unencrypted private key (rsa, ecdsa, ed25519)
//  2. traditional unencrypted private key and encrypted private key (rsa and ecdsa)
//  3. traditional public key (pkcs1 for rsa or pkix for rsa, PKIX for ecdsa and ed25519)
//  4. x509 certificate (rsa, ecdsa, ed25519)
//  5. HMAC key (using custom label "HMAC KEY", i.e. -----BEGIN HMAC KEY-----)
//
// Note that if HMAC is used, the application is responsible for securing the jwks endpoint or encrypting the jwks content.
// This is because HMAC keys are symmetric and must not be exposed publicly. By default, the jwks endpoint is not secured.
type FileJwkStore struct {
	cacheById map[string]Jwk
	cacheByName map[string][]Jwk
	indexes map[string]int
}
// NewFileJwkStore loads every key configured in props.Keys. Keys that fail to load are skipped with a
// warning instead of failing the whole store. Each successfully loaded name starts at rotation index 0.
func NewFileJwkStore(props CryptoProperties) *FileJwkStore {
	store := &FileJwkStore{
		cacheById:   make(map[string]Jwk),
		cacheByName: make(map[string][]Jwk),
		indexes:     make(map[string]int),
	}
	for name, keyProps := range props.Keys {
		jwks, err := loadJwks(name, keyProps)
		if err != nil {
			logger.Warnf("ignored key %s due to error %v", name, err)
			continue
		}
		for _, jwk := range jwks {
			store.cacheById[jwk.Id()] = jwk
		}
		store.cacheByName[name] = jwks
		store.indexes[name] = 0
	}
	return store
}
// LoadByKid returns the loaded JWK with the given kid, or an error if none exists.
func (s *FileJwkStore) LoadByKid(_ context.Context, kid string) (Jwk, error) {
	if jwk, found := s.cacheById[kid]; found {
		return jwk, nil
	}
	return nil, fmt.Errorf("cannot find JWK with kid [%s]", kid)
}
// LoadByName returns the currently selected JWK (per the rotation index) among those loaded under name.
func (s *FileJwkStore) LoadByName(_ context.Context, name string) (Jwk, error) {
	candidates := s.cacheByName[name]
	if len(candidates) == 0 {
		return nil, fmt.Errorf("cannot find JWK with name [%s]", name)
	}
	current := s.indexes[name] % len(candidates)
	return candidates[current], nil
}
// LoadAll returns every JWK loaded under one of the given names, or all JWKs when no name is given.
// The order follows map iteration and is therefore not deterministic.
func (s *FileJwkStore) LoadAll(_ context.Context, names ...string) ([]Jwk, error) {
	wanted := func(name string) bool {
		if len(names) == 0 {
			return true
		}
		for _, n := range names {
			if n == name {
				return true
			}
		}
		return false
	}
	result := make([]Jwk, 0, len(s.cacheById))
	for name, group := range s.cacheByName {
		if wanted(name) {
			result = append(result, group...)
		}
	}
	return result, nil
}
// Rotate advances the rotation index for name to the next loaded key, wrapping around.
func (s *FileJwkStore) Rotate(_ context.Context, name string) error {
	current, indexed := s.indexes[name]
	jwks := s.cacheByName[name]
	if !indexed || len(jwks) == 0 {
		return fmt.Errorf(errTmplInvalidJwkName)
	}
	s.indexes[name] = (current + 1) % len(jwks)
	return nil
}
/*************************
Helpers
*************************/
// loadJwks loads the JWKs described by props; only the PEM file format is supported.
func loadJwks(name string, props CryptoKeyProperties) ([]Jwk, error) {
	if props.Format() != KeyFileFormatPem {
		return nil, fmt.Errorf(errTmplUnsupportedFile, props.KeyFormat)
	}
	return loadJwksFromPem(name, props)
}
// loadJwksFromPem loads all keys of a (possibly multi-block) PEM file as JWKs sharing the given name.
// A file must contain either only private keys (including "HMAC KEY" blocks) or only public keys
// (public key blocks and x509 certificates). Mixing both is an error, as is a file without any supported key.
func loadJwksFromPem(name string, props CryptoKeyProperties) ([]Jwk, error) {
	items, e := cryptoutils.LoadMultiBlockPem(props.Location, props.Password)
	if e != nil {
		return nil, fmt.Errorf("unable to load JWK [%s] - %v", name, e)
	}
	privJwks := make([]Jwk, 0)
	pubJwks := make([]Jwk, 0)
	for i, item := range items {
		var privKey crypto.PrivateKey
		var pubKey crypto.PublicKey
		// classify the block: private key, public key, certificate (public key) or HMAC secret
		switch block := item.(type) {
		case privateKey:
			privKey = block
		case publicKey:
			pubKey = block
		case *x509.Certificate:
			certKey, ok := block.PublicKey.(publicKey)
			if !ok {
				return nil, fmt.Errorf(errTmplUnsupportedPubKey, block.PublicKey)
			}
			pubKey = certKey
		case *pem.Block:
			if block.Type != "HMAC KEY" {
				return nil, fmt.Errorf(errTmplUnsupportedBlock, item)
			}
			logger.Warnf("File contains HMAC keys, please make sure the jwks end point is secured")
			privKey = block.Bytes
		default:
			return nil, fmt.Errorf(errTmplUnsupportedBlock, item)
		}

		// a file must not mix private and public keys
		isPrivate := privKey != nil
		if (isPrivate && len(pubJwks) != 0) || (!isPrivate && len(privJwks) != 0) {
			return nil, fmt.Errorf(errTmplPubPrivMixed)
		}
		if isPrivate {
			kid := calculateKid(props, name, i, len(items), privKey)
			privJwks = append(privJwks, NewPrivateJwk(kid, name, privKey))
		} else {
			kid := calculateKid(props, name, i, len(items), pubKey)
			pubJwks = append(pubJwks, NewJwk(kid, name, pubKey))
		}
	}
	switch {
	case len(pubJwks) == 0 && len(privJwks) == 0:
		return nil, fmt.Errorf(errTmplNoKeyFoundInPem)
	case len(pubJwks) != 0 && len(privJwks) != 0:
		// should be unreachable if the per-item check above is correct
		return nil, fmt.Errorf(errTmplPubPrivMixed)
	case len(pubJwks) != 0:
		return pubJwks, nil
	default:
		return privJwks, nil
	}
}
// calculateKid derives the key id for the block at blockIndex of a PEM file containing numBlocks keys:
//   - a single-key file uses the key name as kid;
//   - when props.Id is set, the kid is "<props.Id>-<blockIndex>";
//   - otherwise the kid is "<name>-<hex digest>". The digest is computed from the public key (or, for HMAC
//     secrets, as an HMAC-SHA256 of the name), so it stays stable across restarts;
//   - key types without a digest rule fall back to "<name>-<blockIndex>" instead of panicking on a nil hash.
func calculateKid(props CryptoKeyProperties, name string, blockIndex int, numBlocks int, key any) string {
	if numBlocks == 1 {
		return name
	}
	if props.Id != "" {
		return fmt.Sprintf("%s-%d", props.Id, blockIndex)
	}
	// best effort to generate a kid that is consistent across restarts
	var h hash.Hash
	switch k := key.(type) {
	case *rsa.PrivateKey:
		h = hashForRSA(&k.PublicKey)
	case *rsa.PublicKey:
		h = hashForRSA(k)
	case *ecdsa.PrivateKey:
		h = hashForEcdsa(k.Public().(*ecdsa.PublicKey))
	case *ecdsa.PublicKey:
		h = hashForEcdsa(k)
	case ed25519.PrivateKey:
		h = hashForEd25519(k.Public().(ed25519.PublicKey))
	case ed25519.PublicKey:
		h = hashForEd25519(k)
	case []byte:
		h = hmac.New(sha256.New, k)
		_, _ = h.Write([]byte(name))
	default:
		// unknown key type (e.g. a custom publicKey/privateKey implementation): use the block position,
		// which is stable as long as the file content does not change
		return fmt.Sprintf("%s-%d", name, blockIndex)
	}
	return name + "-" + hex.EncodeToString(h.Sum(nil))
}
// hashForRSA returns a SHA-224 digest over the modulus bytes followed by the exponent as an
// 8-byte little-endian int64.
func hashForRSA(key *rsa.PublicKey) hash.Hash {
	exponent := make([]byte, 8)
	binary.LittleEndian.PutUint64(exponent, uint64(int64(key.E)))
	digest := sha256.New224()
	for _, part := range [][]byte{key.N.Bytes(), exponent} {
		_, _ = digest.Write(part)
	}
	return digest
}
// hashForEd25519 returns a SHA-224 digest over the raw Ed25519 public key bytes.
func hashForEd25519(key ed25519.PublicKey) hash.Hash {
	digest := sha256.New224()
	_, _ = digest.Write([]byte(key))
	return digest
}
// hashForEcdsa returns a SHA-224 digest over the X coordinate bytes followed by the Y coordinate bytes.
func hashForEcdsa(key *ecdsa.PublicKey) hash.Hash {
	digest := sha256.New224()
	for _, coordinate := range [][]byte{key.X.Bytes(), key.Y.Bytes()} {
		_, _ = digest.Write(coordinate)
	}
	return digest
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jwt
import (
"context"
"fmt"
"github.com/golang-jwt/jwt/v4"
"sync"
)
// SingleJwkStore implements JwkStore.
// It holds a single randomly generated JWK, created lazily on first use via LazyInit. Both the JWK's kid and
// its name are Kid; lookups by any other kid or name return an error.
// This store is mainly intended for testing.
type SingleJwkStore struct {
	initOnce sync.Once
	Kid string
	// SigningMethod is passed to the key generator (presumably selecting the key type); defaults to RS256
	// when created via NewSingleJwkStoreWithOptions
	SigningMethod jwt.SigningMethod
	jwk Jwk
}
// NewSingleJwkStoreWithOptions creates a SingleJwkStore (RS256 by default) and applies the given options to it.
func NewSingleJwkStoreWithOptions(opts ...func(s *SingleJwkStore)) *SingleJwkStore {
	store := &SingleJwkStore{SigningMethod: jwt.SigningMethodRS256}
	for _, apply := range opts {
		apply(store)
	}
	return store
}
// NewSingleJwkStore creates a SingleJwkStore whose only JWK uses the given kid (also used as its name).
//
// Deprecated: Use NewSingleJwkStoreWithOptions
func NewSingleJwkStore(kid string) *SingleJwkStore {
	withKid := func(s *SingleJwkStore) { s.Kid = kid }
	return NewSingleJwkStoreWithOptions(withKid)
}
// LoadByKid returns the store's JWK if kid equals s.Kid, initializing the JWK on first use.
func (s *SingleJwkStore) LoadByKid(_ context.Context, kid string) (Jwk, error) {
	if err := s.LazyInit(); err != nil {
		return nil, err
	}
	if kid != s.Kid {
		return nil, fmt.Errorf("cannot find JWK with Kid [%s]", kid)
	}
	return s.jwk, nil
}
// LoadByName returns the store's JWK if name equals s.Kid (the name doubles as the kid).
func (s *SingleJwkStore) LoadByName(_ context.Context, name string) (Jwk, error) {
	if err := s.LazyInit(); err != nil {
		return nil, err
	}
	if name != s.Kid {
		return nil, fmt.Errorf("cannot find JWK with name [%s]", name)
	}
	return s.jwk, nil
}
// LoadAll returns a slice holding the store's only JWK, regardless of the requested names.
func (s *SingleJwkStore) LoadAll(_ context.Context, _ ...string) ([]Jwk, error) {
	err := s.LazyInit()
	if err != nil {
		return nil, err
	}
	all := make([]Jwk, 1)
	all[0] = s.jwk
	return all, nil
}
// LazyInit generates the store's random JWK on first use.
// sync.Once runs the generator at most once. A generation failure is therefore returned directly only to
// the first caller; later callers would otherwise get a nil error together with a nil JWK. They now
// receive an error as well.
func (s *SingleJwkStore) LazyInit() (err error) {
	s.initOnce.Do(func() {
		s.jwk, err = generateRandomJwk(s.SigningMethod, s.Kid, s.Kid)
	})
	if err == nil && s.jwk == nil {
		err = fmt.Errorf("JWK with Kid [%s] is not available: initialization failed", s.Kid)
	}
	return
}
package jwt
import (
"context"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/cacheutils"
"net/http"
"time"
)
// ckJwkSet is the cache key under which RemoteJwkStore caches the whole JWK Set.
var ckJwkSet = cacheutils.StringKey(`github.com/cisco-open/go-lanai/JWKSet`)
// jwkSet is the JSON body of a JWK Set endpoint: {"keys": [...]}.
type jwkSet struct {
	Keys []*GenericJwk `json:"keys"`
}
// RemoteJwkOptions is a functional option applied to a RemoteJwkConfig by NewRemoteJwkStore.
type RemoteJwkOptions func(cfg *RemoteJwkConfig)
// RemoteJwkConfig configures a RemoteJwkStore.
type RemoteJwkConfig struct {
	// HttpClient the underlying http.Client to use. Default: http.DefaultClient
	HttpClient *http.Client
	// JwkSetURL the URL of JWKSet endpoint for getting all JWKs. Default: "http://localhost:8900/auth/v2/jwks"
	// e.g. http://localhost:8900/auth/v2/jwks
	JwkSetURL string
	// JwkBaseURL the base URL of the endpoint for getting a JWK by kid (without trailing slash). The actual URL would be "JwkBaseURL/<kid>".
	// (Optional) When not set (empty string), the JWK is looked up in the JWK Set from JwkSetURL. Default: "" (not set)
	// e.g. JwkBaseURL = "http://localhost:8900/auth/v2/jwks", actual URL is "http://localhost:8900/auth/v2/jwks/<kid>"
	JwkBaseURL string
	// JwkSetRequestFunc a function that creates the http.Request for the JWKSet endpoint. When set, overrides JwkSetURL.
	// (Optional) When not set, JwkSetURL is used with GET method.
	JwkSetRequestFunc func(ctx context.Context) *http.Request
	// JwkRequestFunc a function that creates the http.Request for "get JWK by kid". When set, overrides JwkBaseURL.
	// (Optional) When not set, JwkBaseURL is used with GET method. If JwkBaseURL is not set either, the JWKSet endpoint is used.
	JwkRequestFunc func(ctx context.Context, kid string) *http.Request
	// DisableCache disables internal caching. If the cache is disabled, the store invokes an external HTTP transaction
	// every time any of the store's methods is called. Default: false
	DisableCache bool
	// TTL cache setting. TTL controls how long a successful HTTP result is kept in cache. Default: 60 minutes
	TTL time.Duration
	// RetryBackoff cache setting. It is the expiry of a failed load result, i.e. how long to wait before the next attempt. Default: 2 seconds
	RetryBackoff time.Duration
	// Retry cache setting. It controls how many times the cache would retry a failed HTTP transaction. Default: 2
	Retry int
}
// NewRemoteJwkStore creates a JwkStore that loads JWKs with public keys from an external JWKSet endpoint.
// Note: use RemoteJwkStore with JwtDecoder ONLY.
//
// RemoteJwkStore is not capable of decrypting private keys from the JWK response.
//
// See RemoteJwkStore for more details.
func NewRemoteJwkStore(opts ...RemoteJwkOptions) *RemoteJwkStore {
	store := &RemoteJwkStore{}
	store.HttpClient = http.DefaultClient
	store.JwkSetURL = "http://localhost:8900/auth/v2/jwks"
	store.TTL = 60 * time.Minute
	store.RetryBackoff = 2 * time.Second
	store.Retry = 2
	for _, apply := range opts {
		apply(&store.RemoteJwkConfig)
	}
	if !store.DisableCache {
		store.cache = cacheutils.NewMemCache(func(opt *cacheutils.CacheOption) {
			opt.Heartbeat = store.TTL
			opt.LoadRetry = store.Retry
		})
	}
	if store.JwkSetRequestFunc == nil {
		store.JwkSetRequestFunc = remoteJwkSetRequestFuncWithUrl(store.JwkSetURL)
	}
	if store.JwkRequestFunc == nil && store.JwkBaseURL != "" {
		store.JwkRequestFunc = remoteJwkRequestFuncWithUrl(store.JwkBaseURL)
	}
	return store
}
// RemoteJwkStore implements JwkStore and loads JWKs with public keys from an external JWKSet endpoint.
// Important: use RemoteJwkStore with JwtDecoder ONLY.
//
// RemoteJwkStore is not capable of decrypting private keys from the JWK response.
//
// Note: LoadByName and LoadAll treat the JWK's "name" as "kid". "name" was introduced for managing
// key rotation, which is not applicable to JwtDecoder: JwtDecoder strictly uses `kid` if present in the header,
// or the default "name" (in which case it should be a hard-coded, globally known "kid").
type RemoteJwkStore struct {
	RemoteJwkConfig
	// cache is nil when DisableCache is set
	cache cacheutils.MemCache
}
// LoadByKid returns the JWK with the given kid, served from the cache when caching is enabled.
func (s *RemoteJwkStore) LoadByKid(ctx context.Context, kid string) (Jwk, error) {
	useCache := !s.DisableCache && s.cache != nil
	if !useCache {
		return s.fetchJwkByKid(ctx, kid)
	}
	cached, err := s.cache.GetOrLoad(ctx, cacheutils.StringKey(kid), s.loadJwkByKid, nil)
	if err != nil {
		return nil, err
	}
	return cached.(Jwk), nil
}
// LoadByName treats name as a kid, because the remote JWK endpoint doesn't provide names.
// It follows exactly the same (cached or uncached) path as LoadByKid.
func (s *RemoteJwkStore) LoadByName(ctx context.Context, name string) (Jwk, error) {
	return s.LoadByKid(ctx, name)
}
// LoadAll returns the JWKs of the remote JWK Set whose kids match names (all of them when names is empty).
func (s *RemoteJwkStore) LoadAll(ctx context.Context, names ...string) ([]Jwk, error) {
	if s.DisableCache || s.cache == nil {
		jwks, err := s.fetchJwkSet(ctx)
		if err != nil {
			return nil, err
		}
		return s.filterJwkSet(jwks, names...), nil
	}
	cached, err := s.cache.GetOrLoad(ctx, ckJwkSet, s.loadJwkSet, nil)
	if err != nil {
		return nil, err
	}
	return s.filterJwkSet(cached.([]Jwk), names...), nil
}
// loadJwkByKid is the cache loader for single JWKs. Successful results expire after TTL; failures
// expire after RetryBackoff.
func (s *RemoteJwkStore) loadJwkByKid(ctx context.Context, k cacheutils.Key) (v interface{}, exp time.Time, err error) {
	kid := string(k.(cacheutils.StringKey))
	if v, err = s.fetchJwkByKid(ctx, kid); err != nil {
		return nil, time.Now().Add(s.RetryBackoff), err
	}
	return v, time.Now().Add(s.TTL), nil
}
// loadJwkSet is the cache loader for the whole JWK Set. Successful results expire after TTL; failures
// expire after RetryBackoff.
func (s *RemoteJwkStore) loadJwkSet(ctx context.Context, _ cacheutils.Key) (v interface{}, exp time.Time, err error) {
	if v, err = s.fetchJwkSet(ctx); err != nil {
		return nil, time.Now().Add(s.RetryBackoff), err
	}
	return v, time.Now().Add(s.TTL), nil
}
// fetchJwkByKid fetches a single JWK by kid. Without a JwkRequestFunc, it searches the full JWK Set instead.
func (s *RemoteJwkStore) fetchJwkByKid(ctx context.Context, kid string) (Jwk, error) {
	if s.JwkRequestFunc == nil {
		jwks, err := s.fetchJwkSet(ctx)
		if err != nil {
			return nil, err
		}
		for i := range jwks {
			if jwks[i].Id() == kid {
				return jwks[i], nil
			}
		}
		return nil, fmt.Errorf(`failed to fetch JWK with kid [%s]: kid does not exist`, kid)
	}

	req := s.JwkRequestFunc(ctx, kid)
	if req == nil {
		return nil, fmt.Errorf(`unable to resolve HTTP request for JWK with kid [%s]`, kid)
	}
	resp, err := s.doFetch(req)
	if err != nil {
		return nil, fmt.Errorf(`failed to fetch JWK with kid [%s]: %v`, kid, err)
	}
	defer func() { _ = resp.Body.Close() }()

	jwk := &GenericJwk{}
	if err := json.NewDecoder(resp.Body).Decode(jwk); err != nil {
		return nil, fmt.Errorf(`unable to parse JWK from JSON: %v`, err)
	}
	return jwk, nil
}
// fetchJwkSet downloads and decodes the JWK Set using JwkSetRequestFunc, bypassing the cache.
//
// A missing JwkSetRequestFunc, or one that returns a nil request, results in an error rather than a panic.
// Underlying fetch and decode errors are wrapped with %w so callers can inspect them.
func (s *RemoteJwkStore) fetchJwkSet(ctx context.Context) ([]Jwk, error) {
	var req *http.Request
	// guard against an unconfigured request func, which would otherwise panic with a nil func call
	if s.JwkSetRequestFunc != nil {
		req = s.JwkSetRequestFunc(ctx)
	}
	if req == nil {
		return nil, fmt.Errorf(`unable to resolve HTTP request for JWK Set`)
	}
	resp, e := s.doFetch(req)
	if e != nil {
		return nil, fmt.Errorf(`failed to fetch JWK Set: %w`, e)
	}
	defer func() { _ = resp.Body.Close() }()

	var set jwkSet
	if e := json.NewDecoder(resp.Body).Decode(&set); e != nil {
		return nil, fmt.Errorf(`unable to parse JWK Set from JSON: %w`, e)
	}
	jwks := make([]Jwk, len(set.Keys))
	for i := range set.Keys {
		jwks[i] = set.Keys[i]
	}
	return jwks, nil
}
// doFetch executes req with the configured HttpClient and returns the response only when its status code
// is 2xx. The caller owns (and must close) the body of a returned response.
//
// For non-2xx responses the body is closed here before the error is returned: the caller never receives
// that response, so nothing else would ever close it (previously this leaked the body and its connection).
func (s *RemoteJwkStore) doFetch(req *http.Request) (*http.Response, error) {
	resp, e := s.HttpClient.Do(req)
	if e != nil {
		return nil, e
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		_ = resp.Body.Close()
		return nil, fmt.Errorf(`failed with status code [%d: %s]`, resp.StatusCode, resp.Status)
	}
	return resp, nil
}
// filterJwkSet returns the keys of jwks whose Id() is one of kids, preserving the order of jwks.
// When no kids are given, jwks is returned as-is.
func (s *RemoteJwkStore) filterJwkSet(jwks []Jwk, kids ...string) []Jwk {
	if len(kids) == 0 {
		return jwks
	}
	wanted := make(map[string]struct{}, len(kids))
	for _, kid := range kids {
		wanted[kid] = struct{}{}
	}
	filtered := make([]Jwk, 0, len(jwks))
	for _, jwk := range jwks {
		if _, ok := wanted[jwk.Id()]; ok {
			filtered = append(filtered, jwk)
		}
	}
	return filtered
}
// remoteJwkSetRequestFuncWithUrl returns a request factory that builds a GET request to base, bound to
// the given context. The factory returns nil if the request cannot be created (e.g. base is not a valid URL).
func remoteJwkSetRequestFuncWithUrl(base string) func(ctx context.Context) *http.Request {
	return func(ctx context.Context) *http.Request {
		if req, e := http.NewRequestWithContext(ctx, http.MethodGet, base, nil); e == nil {
			return req
		}
		return nil
	}
}
// remoteJwkRequestFuncWithUrl returns a request factory that builds a GET request to "<base>/<kid>", bound
// to the given context. The factory returns nil if the request cannot be created.
// NOTE(review): kid is inserted verbatim (not path-escaped); confirm kids never contain '/', '?' or '#'.
func remoteJwkRequestFuncWithUrl(base string) func(ctx context.Context, kid string) *http.Request {
	return func(ctx context.Context, kid string) *http.Request {
		target := fmt.Sprintf(`%s/%s`, base, kid)
		if req, e := http.NewRequestWithContext(ctx, http.MethodGet, target, nil); e == nil {
			return req
		}
		return nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jwt
import (
"context"
"fmt"
"github.com/golang-jwt/jwt/v4"
)
var (
// kidRoundRobin is the default value of StaticJwkStore.KIDs used by NewStaticJwkStoreWithOptions.
kidRoundRobin = []string{"kid1", "kid2", "kid3"}
)
// StaticJwkStore implements JwkStore and JwkRotator with a fixed list of key IDs.
// A random private JWK is generated on first use of each "kid" (see getOrNew) and cached in lookup,
// so the same "kid" returns the same key for the lifetime of the store instance.
// Note: the key is not derived from the "kid"; a different store instance yields different keys.
// This store is not thread safe: Rotate and getOrNew mutate current and lookup without synchronization.
type StaticJwkStore struct {
// KIDs lists the key IDs served by this store; LoadByKid only accepts values from this list.
KIDs []string
// SigningMethod determines the type of private key generated for each kid.
SigningMethod jwt.SigningMethod
// current is the index into KIDs of the key returned by LoadByName; Rotate advances it round-robin.
current int
// lookup caches generated keys by kid.
lookup map[string]Jwk
}
// NewStaticJwkStoreWithOptions creates a StaticJwkStore with default KIDs ("kid1", "kid2", "kid3") and
// jwt.SigningMethodRS256, then applies each non-nil option in order.
func NewStaticJwkStoreWithOptions(opts ...func(s *StaticJwkStore)) *StaticJwkStore {
	store := StaticJwkStore{
		// copy the package-level defaults so that mutating store.KIDs can never alter kidRoundRobin
		KIDs:          append([]string(nil), kidRoundRobin...),
		SigningMethod: jwt.SigningMethodRS256,
		lookup:        map[string]Jwk{},
	}
	for _, fn := range opts {
		if fn != nil {
			fn(&store)
		}
	}
	return &store
}
// NewStaticJwkStore creates a StaticJwkStore serving the given kids, or the default KIDs when none are given.
//
// Deprecated: Use NewStaticJwkStoreWithOptions
func NewStaticJwkStore(kids ...string) *StaticJwkStore {
	withKids := func(s *StaticJwkStore) {
		if len(kids) == 0 {
			return
		}
		s.KIDs = kids
	}
	return NewStaticJwkStoreWithOptions(withKids)
}
// Rotate advances the current key to the next entry of KIDs, wrapping around at the end.
// It returns an error instead of panicking (integer division by zero) when no KIDs are configured.
func (s *StaticJwkStore) Rotate(_ context.Context, _ string) error {
	if len(s.KIDs) == 0 {
		return fmt.Errorf("unable to rotate JWK: no KIDs are configured")
	}
	s.current = (s.current + 1) % len(s.KIDs)
	return nil
}
// LoadByKid returns the key for kid, generating and caching it on first use.
// Only kids listed in KIDs are accepted; any other kid results in an error.
func (s *StaticJwkStore) LoadByKid(_ context.Context, kid string) (Jwk, error) {
	for _, known := range s.KIDs {
		if known == kid {
			return s.getOrNew(kid)
		}
	}
	return nil, fmt.Errorf("JWK with kid '%s' is not available", kid)
}
// LoadByName ignores the name and returns the current key (KIDs[current]), generating it on first use.
// It returns an error instead of panicking when KIDs is empty or no longer contains the current index.
func (s *StaticJwkStore) LoadByName(_ context.Context, _ string) (Jwk, error) {
	if s.current < 0 || s.current >= len(s.KIDs) {
		return nil, fmt.Errorf("current JWK is not available: %d KIDs configured", len(s.KIDs))
	}
	return s.getOrNew(s.KIDs[s.current])
}
// LoadAll returns the keys of all KIDs in order, generating any that don't exist yet. Names are ignored.
// A key-generation failure is returned as an error instead of silently leaving a nil entry in the result.
func (s *StaticJwkStore) LoadAll(_ context.Context, _ ...string) ([]Jwk, error) {
	jwks := make([]Jwk, 0, len(s.KIDs))
	for _, kid := range s.KIDs {
		jwk, e := s.getOrNew(kid)
		if e != nil {
			return nil, e
		}
		jwks = append(jwks, jwk)
	}
	return jwks, nil
}
// getOrNew returns the cached key for kid, or generates a random one with SigningMethod (using kid as
// both kid and name), caches it and returns it.
// The lookup map is created lazily, so a zero-value StaticJwkStore does not panic on first use.
func (s *StaticJwkStore) getOrNew(kid string) (Jwk, error) {
	if jwk, ok := s.lookup[kid]; ok {
		return jwk, nil
	}
	jwk, e := generateRandomJwk(s.SigningMethod, kid, kid)
	if e != nil {
		return nil, e
	}
	if s.lookup == nil {
		s.lookup = map[string]Jwk{}
	}
	s.lookup[kid] = jwk
	return jwk, nil
}
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"fmt"
"github.com/golang-jwt/jwt/v4"
)
// resolveSigningMethod picks a JWT signing method compatible with the given private key:
//   - *rsa.PrivateKey: RS256
//   - *ecdsa.PrivateKey: ES512, ES384 or ES256 depending on the curve bit size (at least 256 bits required)
//   - ed25519.PrivateKey: EdDSA
//   - []byte (MAC secret): HS512, HS384 or HS256 depending on the secret length (at least 32 bytes required)
//
// Any other key type results in an error.
func resolveSigningMethod(key crypto.PrivateKey) (jwt.SigningMethod, error) {
	switch v := key.(type) {
	case *rsa.PrivateKey:
		return jwt.SigningMethodRS256, nil
	case *ecdsa.PrivateKey:
		size := v.Curve.Params().BitSize
		switch {
		case size >= 521:
			return jwt.SigningMethodES512, nil
		case size >= 384:
			return jwt.SigningMethodES384, nil
		case size >= 256:
			return jwt.SigningMethodES256, nil
		default:
			return nil, fmt.Errorf(`invalid ECDSA private key. Expect P-256 or more, but got %s`, v.Curve.Params().Name)
		}
	case ed25519.PrivateKey:
		return jwt.SigningMethodEdDSA, nil
	case []byte:
		// RFC 7518: the HMAC key must be at least as long as the hash output (lengths compared in bytes)
		switch n := len(v); {
		case n >= 512/8:
			return jwt.SigningMethodHS512, nil
		case n >= 384/8:
			return jwt.SigningMethodHS384, nil
		case n >= 256/8:
			return jwt.SigningMethodHS256, nil
		default:
			// the minimum is 256 bits, i.e. 32 bytes (the old message wrongly said "256B")
			return nil, fmt.Errorf(`invalid MAC secret. Expect 32B (256 bits) or more, but got %dB`, n)
		}
	default:
		// Note: *ecdh.PrivateKey is not supported by github.com/golang-jwt/jwt/v4
		return nil, fmt.Errorf(`unable to find proper signing method: unrecognized private key type: %T`, key)
	}
}
// generateCompatiblePrivateKey generates a new random private key usable with the given signing method:
// RSA (rsaKeySize bits) for RS*/PS*, ECDSA on P-256/P-384/P-521 for ES256/ES384/ES512, Ed25519 for EdDSA,
// and a random secret as long as the hash output for HS256/HS384/HS512.
// Unsupported methods result in an error.
func generateCompatiblePrivateKey(method jwt.SigningMethod) (crypto.PrivateKey, error) {
	switch method {
	case jwt.SigningMethodRS256, jwt.SigningMethodRS384, jwt.SigningMethodRS512,
		jwt.SigningMethodPS256, jwt.SigningMethodPS384, jwt.SigningMethodPS512:
		return rsa.GenerateKey(rand.Reader, rsaKeySize)
	case jwt.SigningMethodES256:
		return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	case jwt.SigningMethodES384:
		return ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
	case jwt.SigningMethodES512:
		return ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
	case jwt.SigningMethodEdDSA:
		_, priv, e := ed25519.GenerateKey(rand.Reader)
		return priv, e
	case jwt.SigningMethodHS256, jwt.SigningMethodHS384, jwt.SigningMethodHS512:
		// RFC7518: When using HMAC
		// "key of the same size as the hash output (for instance, 256 bits for "HS256") or larger MUST be used with this algorithm."
		var bits int
		switch method {
		case jwt.SigningMethodHS256:
			bits = 256
		case jwt.SigningMethodHS384:
			bits = 384
		default:
			bits = 512
		}
		secret := make([]byte, bits/8)
		// rand.Read guarantees the whole buffer is filled on success; a single Reader.Read call may return
		// fewer bytes than requested, which would leave trailing zero bytes in the secret.
		if _, e := rand.Read(secret); e != nil {
			return nil, e
		}
		return secret, nil
	default:
		return nil, fmt.Errorf(`unsupported signing method: %T`, method)
	}
}
// generateRandomJwk creates a private JWK with the given kid and name, backed by a freshly generated key
// compatible with method.
func generateRandomJwk(method jwt.SigningMethod, kid, name string) (Jwk, error) {
	key, err := generateCompatiblePrivateKey(method)
	switch {
	case err != nil:
		return nil, err
	default:
		return NewPrivateJwk(kid, name, key), nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jwt
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/pkg/errors"
"strings"
)
// logger is this package's logger, named "OAuth2.JWT".
var logger = log.New("OAuth2.JWT")
/***********************
Crypto
************************/
// CryptoKeysPropertiesPrefix is the configuration prefix BindCryptoProperties binds CryptoProperties from.
const CryptoKeysPropertiesPrefix = "security"
const (
// KeyFileFormatPem is the KeyFormatType for PEM key files.
KeyFileFormatPem KeyFormatType = "pem"
)
// KeyFormatType names a key file format; CryptoKeyProperties.Format returns it lower-cased.
type KeyFormatType string
// CryptoProperties holds key definitions and JWT settings, bound from CryptoKeysPropertiesPrefix.
type CryptoProperties struct {
// Keys holds the key definitions, bound from "keys" (presumably keyed by key name — verify).
Keys map[string]CryptoKeyProperties `json:"keys"`
// Jwt holds JWT-specific settings, bound from "jwt".
Jwt JwtProperties `json:"jwt"`
}
// JwtProperties holds JWT-specific crypto settings.
type JwtProperties struct {
// KeyName is bound from "key-name".
KeyName string `json:"key-name"`
}
// CryptoKeyProperties describes a single key.
type CryptoKeyProperties struct {
// Id is bound from "id".
Id string `json:"id"`
// KeyFormat is the raw format string bound from "format"; use Format for the normalized value.
KeyFormat string `json:"format"`
// Location is the key file location, bound from "file".
Location string `json:"file"`
// Password is bound from "password".
Password string `json:"password"`
}
// Format returns KeyFormat lower-cased as a KeyFormatType, so "PEM" and "pem" compare equal.
func (p CryptoKeyProperties) Format() KeyFormatType {
	normalized := strings.ToLower(p.KeyFormat)
	return KeyFormatType(normalized)
}
// NewCryptoProperties creates a CryptoProperties with default values: an empty, non-nil Keys map.
func NewCryptoProperties() *CryptoProperties {
	props := CryptoProperties{
		Keys: make(map[string]CryptoKeyProperties),
	}
	return &props
}
// BindCryptoProperties creates a CryptoProperties with default values and binds the application
// configuration under CryptoKeysPropertiesPrefix onto it. It panics if binding fails.
func BindCryptoProperties(ctx *bootstrap.ApplicationContext) CryptoProperties {
	props := NewCryptoProperties()
	err := ctx.Config().Bind(props, CryptoKeysPropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind CryptoProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"encoding/gob"
"github.com/cisco-open/go-lanai/pkg/log"
)
// logger is this package's logger, named "OAuth2".
var logger = log.New("OAuth2")

// init registers this package's concrete types with encoding/gob so that they can be encoded and decoded
// when transmitted as interface values.
func init() {
	for _, v := range []interface{}{
		(*authentication)(nil),
		(*OAuth2Request)(nil),
		(*OAuth2Error)(nil),
	} {
		gob.Register(v)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package timeoutsupport
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"go.uber.org/fx"
)
//var logger = log.New("SEC.Timeout")
// Module is the bootstrap module for session timeout support. It provides the bound
// security.TimeoutSupportProperties and an oauth2.TimeoutApplier (see provideTimeoutSupport).
var Module = &bootstrap.Module{
Name: "timeout",
Precedence: security.MinSecurityPrecedence + 10, //same as session. since this package doesn't invoke anything, the precedence has no real effect
Options: []fx.Option{
fx.Provide(security.BindTimeoutSupportProperties),
fx.Provide(provideTimeoutSupport),
},
}
// timeoutDI is an empty struct.
// NOTE(review): it is not referenced in this file; verify whether other files use it before removing.
type timeoutDI struct {
}
// provideTimeoutSupport creates a Redis client on the configured DbIndex and wraps it in a
// RedisTimeoutApplier. It panics if the client cannot be created.
func provideTimeoutSupport(ctx *bootstrap.ApplicationContext, cf redis.ClientFactory, prop security.TimeoutSupportProperties) oauth2.TimeoutApplier {
	withDbIndex := func(opt *redis.ClientOption) {
		opt.DbIndex = prop.DbIndex
	}
	client, err := cf.New(ctx, withDbIndex)
	if err != nil {
		panic(err)
	}
	return NewRedisTimeoutApplier(client)
}
// Use registers Module with bootstrap.
func Use() {
bootstrap.Register(Module)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package timeoutsupport
import (
"context"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security/session/common"
"strconv"
"time"
)
// RedisTimeoutApplier applies session timeouts to sessions stored in Redis.
type RedisTimeoutApplier struct {
// sessionName is used with common.GetRedisSessionKey to build session keys.
sessionName string
// client is the Redis client used for all session reads and writes.
client redis.Client
}
// NewRedisTimeoutApplier creates a RedisTimeoutApplier that uses the default session name
// (common.DefaultName) and the given Redis client.
func NewRedisTimeoutApplier(client redis.Client) *RedisTimeoutApplier {
	applier := new(RedisTimeoutApplier)
	applier.sessionName = common.DefaultName
	applier.client = client
	return applier
}
// ApplyTimeout refreshes the timeout of the Redis-backed session identified by sessionId.
//
// It returns valid=false and a nil error when the session key does not exist. Otherwise it does four things:
//   - reads the session's idle-timeout duration and absolute-timeout (unix seconds) fields;
//   - computes the resulting expiration with common.CalculateExpiration;
//   - records the current time as the session's last-accessed time;
//   - sets the key's expiration in Redis when the session can expire.
//
// Once the session is known to exist, any later error is returned together with valid=true.
// A stored timeout value that is not a string is reported as a parse error rather than causing a panic.
func (r *RedisTimeoutApplier) ApplyTimeout(ctx context.Context, sessionId string) (valid bool, err error) {
	key := common.GetRedisSessionKey(r.sessionName, sessionId)

	// check if session exists
	existCmd := r.client.Exists(ctx, key)
	if err = existCmd.Err(); err != nil {
		return false, err
	}
	if valid = existCmd.Val() == 1; !valid {
		return false, nil
	}

	hmGetCmd := r.client.HMGet(ctx, key, common.SessionIdleTimeoutDuration, common.SessionAbsTimeoutTime)
	if err = hmGetCmd.Err(); err != nil {
		return valid, err
	}
	result, _ := hmGetCmd.Result()

	var timeoutSetting common.TimeoutSetting = 0
	var idleExpiration, absExpiration time.Time
	now := time.Now()
	if result[0] != nil {
		// a non-string value leaves raw empty, which ParseDuration rejects with an error instead of panicking
		raw, _ := result[0].(string)
		idleTimeout, e := time.ParseDuration(raw)
		if e != nil {
			return valid, e
		}
		idleExpiration = now.Add(idleTimeout)
		timeoutSetting = timeoutSetting | common.IdleTimeoutEnabled
	}
	if result[1] != nil {
		// same as above: a non-string value fails ParseInt instead of panicking
		raw, _ := result[1].(string)
		absTimeoutUnixTime, e := strconv.ParseInt(raw, 10, 0)
		if e != nil {
			return valid, e
		}
		absExpiration = time.Unix(absTimeoutUnixTime, 0)
		timeoutSetting = timeoutSetting | common.AbsoluteTimeoutEnabled
	}
	canExpire, expiration := common.CalculateExpiration(timeoutSetting, idleExpiration, absExpiration)

	// update session last accessed time
	if err = r.client.HSet(ctx, key, common.SessionLastAccessedField, now.Unix()).Err(); err != nil {
		return valid, err
	}
	if canExpire {
		err = r.client.ExpireAt(ctx, key, expiration).Err()
	}
	return valid, err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tokenauth
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
)
/************************
Access Control
************************/
// ScopesApproved returns an access.ControlFunc that grants access only to an oauth2.Authentication whose
// OAuth2 request is approved and whose approved scopes include every one of scopes.
// With no scopes, the returned func always grants access.
func ScopesApproved(scopes ...string) access.ControlFunc {
	if len(scopes) == 0 {
		return func(_ security.Authentication) (bool, error) {
			return true, nil
		}
	}
	return func(auth security.Authentication) (decision bool, reason error) {
		denied := security.NewAccessDeniedError("required scope was not approved by user")
		oauth, ok := auth.(oauth2.Authentication)
		if !ok {
			return false, denied
		}
		req := oauth.OAuth2Request()
		if req == nil || !req.Approved() {
			return false, denied
		}
		if approved := req.Scopes(); approved == nil || !approved.HasAll(scopes...) {
			return false, denied
		}
		return true, nil
	}
}
/******************************
Access Control Conditions
*******************************/
// RequireScopes returns an access.ControlCondition backed by ScopesApproved(scopes...).
// The description lists the scopes once; formatting a slice with %v already adds the surrounding
// brackets, so the old "[%s]" format produced "[[a b]]".
func RequireScopes(scopes ...string) access.ControlCondition {
	return &access.ConditionWithControlFunc{
		Description: fmt.Sprintf("client has scopes %v approved", scopes),
		ControlFunc: ScopesApproved(scopes...),
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tokenauth
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"reflect"
)
/******************************
security.Authenticator
******************************/
// Authenticator implements security.Authenticator for *BearerToken candidates by reading the
// authentication associated with the access token from a token store.
type Authenticator struct {
tokenStoreReader oauth2.TokenStoreReader
}
// AuthenticatorOptions configures an AuthenticatorOption; used by NewAuthenticator.
type AuthenticatorOptions func(opt *AuthenticatorOption)
// AuthenticatorOption holds the settings applied by NewAuthenticator.
type AuthenticatorOption struct {
// TokenStoreReader is used to resolve access tokens into authentications.
TokenStoreReader oauth2.TokenStoreReader
}
// NewAuthenticator creates an Authenticator from the given options; nil options are skipped.
func NewAuthenticator(options ...AuthenticatorOptions) *Authenticator {
	var opt AuthenticatorOption
	for _, apply := range options {
		if apply == nil {
			continue
		}
		apply(&opt)
	}
	return &Authenticator{tokenStoreReader: opt.TokenStoreReader}
}
// Authenticate resolves a *BearerToken candidate into the oauth2 authentication stored for its access token.
// Candidates of any other type are not handled (nil, nil).
//
// The resolved authentication is rejected with an invalid-access-token error when any of these hold:
//   - it is missing or not at least security.StateAuthenticated;
//   - it has no OAuth2 request or client id;
//   - it has a user authentication whose principal is nil or a zero value.
func (a *Authenticator) Authenticate(ctx context.Context, candidate security.Candidate) (security.Authentication, error) {
	can, ok := candidate.(*BearerToken)
	if !ok {
		return nil, nil
	}
	// TODO add remote check_token endpoint support
	auth, e := a.tokenStoreReader.ReadAuthentication(ctx, can.Token, oauth2.TokenHintAccessToken)
	if e != nil {
		return nil, e
	}

	// perform some checks
	if auth == nil || auth.State() < security.StateAuthenticated {
		return nil, oauth2.NewInvalidAccessTokenError("token is not associated with an authenticated session")
	}
	if req := auth.OAuth2Request(); req == nil || req.ClientId() == "" {
		return nil, oauth2.NewInvalidAccessTokenError("token is not issued to a valid client")
	}
	if user := auth.UserAuthentication(); user != nil {
		// reflect.ValueOf(nil) is the invalid zero Value and calling IsZero on it panics,
		// so a nil principal must be checked explicitly.
		if p := user.Principal(); p == nil || reflect.ValueOf(p).IsZero() {
			return nil, oauth2.NewInvalidAccessTokenError("token is not authorized by a valid user")
		}
	}
	return auth, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tokenauth
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
)
var (
// FeatureId identifies the OAuth2 token authentication feature.
FeatureId = security.FeatureId("OAuth2TokenAuth", security.FeatureOrderOAuth2TokenAuth)
)
// TokenAuthConfigurer applies TokenAuthFeature to a security.WebSecurity.
//goland:noinspection GoNameStartsWithPackageName
type TokenAuthConfigurer struct {
tokenStoreReader oauth2.TokenStoreReader
}
// TokenAuthOptions configures a TokenAuthOption; used by NewTokenAuthConfigurer.
//goland:noinspection GoNameStartsWithPackageName
type TokenAuthOptions func(opt *TokenAuthOption)
// TokenAuthOption holds the settings applied by NewTokenAuthConfigurer.
//goland:noinspection GoNameStartsWithPackageName
type TokenAuthOption struct {
// TokenStoreReader is required; validate rejects a configurer without one.
TokenStoreReader oauth2.TokenStoreReader
}
// NewTokenAuthConfigurer creates a TokenAuthConfigurer from the given options.
// Nil options are skipped (previously they caused a nil func call panic), consistent with
// NewAuthenticator and NewTokenAuthMiddleware.
func NewTokenAuthConfigurer(opts ...TokenAuthOptions) *TokenAuthConfigurer {
	opt := TokenAuthOption{}
	for _, f := range opts {
		if f != nil {
			f(&opt)
		}
	}
	return &TokenAuthConfigurer{
		tokenStoreReader: opt.TokenStoreReader,
	}
}
// Apply installs OAuth2 bearer-token authentication onto ws for the given feature. It does four things:
//   - validates the configurer and feature (defaulting the feature's error handler);
//   - registers the feature's error handler as an additional error handler;
//   - adds an Authenticator backed by the token store reader to ws's composite authenticator;
//   - installs the token authentication middleware at security.MWOrderOAuth2TokenAuth.
//
// It returns an error instead of panicking in two cases: feature is not a *TokenAuthFeature, or ws's
// authenticator is not a *security.CompositeAuthenticator.
func (c *TokenAuthConfigurer) Apply(feature security.Feature, ws security.WebSecurity) (err error) {
	// Verify
	f, ok := feature.(*TokenAuthFeature)
	if !ok {
		return fmt.Errorf("unable to apply token auth: unsupported feature type [%T]", feature)
	}
	if err = c.validate(f, ws); err != nil {
		return err
	}

	// configure other features
	errorhandling.Configure(ws).
		AdditionalErrorHandler(f.errorHandler)

	// use ScopesApproved(...) for scope based access decision maker
	// setup authenticator
	authenticator := NewAuthenticator(func(opt *AuthenticatorOption) {
		opt.TokenStoreReader = c.tokenStoreReader
	})
	composite, ok := ws.Authenticator().(*security.CompositeAuthenticator)
	if !ok {
		return fmt.Errorf("unable to apply token auth: authenticator [%T] is not a *security.CompositeAuthenticator", ws.Authenticator())
	}
	composite.Add(authenticator)

	// prepare middlewares
	successHandler, ok := ws.Shared(security.WSSharedKeyCompositeAuthSuccessHandler).(security.AuthenticationSuccessHandler)
	if !ok {
		successHandler = security.NewAuthenticationSuccessHandler()
	}
	mw := NewTokenAuthMiddleware(func(opt *TokenAuthMWOption) {
		opt.Authenticator = ws.Authenticator()
		opt.SuccessHandler = successHandler
		opt.PostBodyEnabled = f.postBodyEnabled
	})

	// install middlewares
	tokenAuth := middleware.NewBuilder("token authentication").
		Order(security.MWOrderOAuth2TokenAuth).
		Use(mw.AuthenticateHandlerFunc())
	ws.Add(tokenAuth)
	return nil
}
// validate ensures a token store reader is configured, and otherwise defaults the feature's error handler
// to a new OAuth2ErrorHandler when none was set.
func (c *TokenAuthConfigurer) validate(f *TokenAuthFeature, _ security.WebSecurity) error {
	switch {
	case c.tokenStoreReader == nil:
		return fmt.Errorf("token store reader is not pre-configured")
	case f.errorHandler == nil:
		f.errorHandler = NewOAuth2ErrorHanlder()
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tokenauth
/************************
security.Candidate
************************/
// BearerToken is the security.Candidate accepted by the resource server's Authenticator.
// It carries the raw access token value; Principal always returns an empty string.
type BearerToken struct {
	// Token is the raw access token value, returned by Credentials.
	Token string
	// DetailsMap holds additional details, returned by Details.
	DetailsMap map[string]interface{}
}

// Principal implements security.Candidate; it always returns "".
func (t *BearerToken) Principal() interface{} {
	var none string
	return none
}

// Credentials implements security.Candidate and returns the raw access token.
func (t *BearerToken) Credentials() interface{} {
	token := t.Token
	return token
}

// Details implements security.Candidate and returns DetailsMap.
func (t *BearerToken) Details() interface{} {
	details := t.DetailsMap
	return details
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tokenauth
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"net/http"
)
// OAuth2ErrorHandler implements security.ErrorHandler.
// It writes responses only for errors that are OAuth2 errors; other errors are ignored.
type OAuth2ErrorHandler struct{}

// NewOAuth2ErrorHanlder returns a new OAuth2ErrorHandler.
func NewOAuth2ErrorHanlder() *OAuth2ErrorHandler {
	return new(OAuth2ErrorHandler)
}
// HandleError implements security.ErrorHandler by delegating to handleError.
func (h *OAuth2ErrorHandler) HandleError(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	h.handleError(c, r, rw, err)
}

// handleError writes err as an OAuth2 error response when it both unwraps to an
// oauth2.OAuth2ErrorTranslator and matches oauth2.ErrorTypeOAuth2; any other error is ignored here.
func (h *OAuth2ErrorHandler) handleError(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	var translator oauth2.OAuth2ErrorTranslator
	if !errors.As(err, &translator) || !errors.Is(err, oauth2.ErrorTypeOAuth2) {
		return
	}
	writeOAuth2Error(c, r, rw, translator)
}
func writeOAuth2Error(c context.Context, r *http.Request, rw http.ResponseWriter, err oauth2.OAuth2ErrorTranslator) {
challenge := ""
sc := err.TranslateStatusCode()
if sc == http.StatusUnauthorized || sc == http.StatusForbidden {
challenge = fmt.Sprintf("%s %s", "Bearer", err.Error())
}
writeAdditionalHeader(c, r, rw, challenge)
security.WriteError(c, r, rw, sc, err)
}
// writeAdditionalHeader adds no-cache headers and, when challenge is non-empty, sets the WWW-Authenticate
// header. It does nothing if the response has already been written.
// (Stray statement-terminating semicolons were removed so the code is gofmt-clean.)
func writeAdditionalHeader(_ context.Context, _ *http.Request, rw http.ResponseWriter, challenge string) {
	if security.IsResponseWritten(rw) {
		return
	}
	rw.Header().Add("Cache-Control", "no-store")
	rw.Header().Add("Pragma", "no-cache")
	if challenge != "" {
		rw.Header().Set("WWW-Authenticate", challenge)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tokenauth
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
)
// TokenAuthFeature is the security.Feature for OAuth2 bearer token authentication.
//goland:noinspection GoNameStartsWithPackageName
type TokenAuthFeature struct {
// errorHandler handles OAuth2 errors; when nil, TokenAuthConfigurer.validate sets a default.
errorHandler *OAuth2ErrorHandler
// postBodyEnabled allows the access token to be read from the POST form when no Authorization header is present.
postBodyEnabled bool
}
// Identifier implements security.Feature and returns FeatureId.
func (f *TokenAuthFeature) Identifier() security.FeatureIdentifier {
return FeatureId
}
// Configure is the standard security.Feature entrypoint: it enables a new TokenAuthFeature on ws and returns
// it for further configuration. It panics if ws does not implement security.FeatureModifier.
// use (*access.AccessControl).AllowIf(ScopesApproved(...)) for scope based access decision maker
func Configure(ws security.WebSecurity) *TokenAuthFeature {
	feature := New()
	if fc, ok := ws.(security.FeatureModifier); ok {
		return fc.Enable(feature).(*TokenAuthFeature)
	}
	// this configures token auth, not the authorization server (the old message said "oauth2 authserver")
	panic(fmt.Errorf("unable to configure oauth2 token auth: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
}
// New is the standard security.Feature entrypoint, DSL style, for use with security.WebSecurity.
// The returned feature has no error handler and has POST body tokens disabled.
// use (*access.AccessControl).AllowIf(ScopesApproved(...)) for scope based access decision maker
func New() *TokenAuthFeature {
	return new(TokenAuthFeature)
}

/** Setters **/

// ErrorHandler sets the handler used for OAuth2 errors and returns f for chaining.
func (f *TokenAuthFeature) ErrorHandler(errorHandler *OAuth2ErrorHandler) *TokenAuthFeature {
	f.errorHandler = errorHandler
	return f
}

// EnablePostBody allows the access token to be read from the POST form and returns f for chaining.
func (f *TokenAuthFeature) EnablePostBody() *TokenAuthFeature {
	f.postBodyEnabled = true
	return f
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tokenauth
import (
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/gin-gonic/gin"
"strings"
)
const (
// bearerTokenPrefix is the Authorization header scheme prefix; it is matched case-insensitively.
bearerTokenPrefix = "Bearer "
)
/****************************
Token Authentication
****************************/
// TokenAuthMiddleware authenticates requests that carry an OAuth2 bearer access token.
//goland:noinspection GoNameStartsWithPackageName
type TokenAuthMiddleware struct {
authenticator security.Authenticator
successHandler security.AuthenticationSuccessHandler
postBodyEnabled bool
}
// TokenAuthMWOptions configures a TokenAuthMWOption; used by NewTokenAuthMiddleware.
//goland:noinspection GoNameStartsWithPackageName
type TokenAuthMWOptions func(opt *TokenAuthMWOption)
// TokenAuthMWOption holds the settings applied by NewTokenAuthMiddleware.
//goland:noinspection GoNameStartsWithPackageName
type TokenAuthMWOption struct {
// Authenticator authenticates the BearerToken candidate.
Authenticator security.Authenticator
// SuccessHandler is invoked after successful authentication.
SuccessHandler security.AuthenticationSuccessHandler
// PostBodyEnabled allows reading the token from the POST form when no Authorization header is present.
PostBodyEnabled bool
}
// NewTokenAuthMiddleware creates a TokenAuthMiddleware from the given options; nil options are skipped.
func NewTokenAuthMiddleware(opts ...TokenAuthMWOptions) *TokenAuthMiddleware {
	var opt TokenAuthMWOption
	for _, apply := range opts {
		if apply == nil {
			continue
		}
		apply(&opt)
	}
	mw := TokenAuthMiddleware{
		authenticator:   opt.Authenticator,
		successHandler:  opt.SuccessHandler,
		postBodyEnabled: opt.PostBodyEnabled,
	}
	return &mw
}
// AuthenticateHandlerFunc returns the gin handler that performs bearer token authentication.
// The current authentication is always cleared first. Then:
//   - a malformed Authorization header or an authentication failure aborts the request with an error;
//   - a missing token leaves the request unauthenticated and lets the chain continue;
//   - on success, the new authentication is set and the success handler is notified.
func (mw *TokenAuthMiddleware) AuthenticateHandlerFunc() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		// We always re-authenticate by clearing current auth
		before := security.Get(ctx)
		security.MustClear(ctx)

		// grab bearer token and create candidate
		tokenValue, e := mw.extractAccessToken(ctx)
		switch {
		case e != nil:
			mw.handleError(ctx, e)
			return
		case tokenValue == "":
			// token is not present, we continue the MW chain
			return
		}

		// Authenticate
		auth, err := mw.authenticator.Authenticate(ctx, &BearerToken{
			Token:      tokenValue,
			DetailsMap: map[string]interface{}{},
		})
		if err != nil {
			mw.handleError(ctx, err)
			return
		}
		mw.handleSuccess(ctx, before, auth)
	}
}
// handleSuccess stores the new authentication (when non-nil) in the request context and notifies the
// success handler, if one is configured. Nothing is written to the response on success.
// The parameter formerly named `new` (shadowing the builtin) is now `after`; a nil successHandler, possible
// when NewTokenAuthMiddleware is called without SuccessHandler, no longer causes a nil interface call.
func (mw *TokenAuthMiddleware) handleSuccess(c *gin.Context, before, after security.Authentication) {
	if after != nil {
		security.MustSet(c, after)
	}
	if mw.successHandler != nil {
		mw.successHandler.HandleAuthenticationSuccess(c, c.Request, c.Writer, before, after)
	}
}
// extractAccessToken returns the bearer access token of the request, or "" when none is present.
//
// The token is taken from the "Authorization: Bearer <token>" header, with the scheme matched
// case-insensitively. When the header is absent and POST body tokens are enabled, the
// oauth2.ParameterAccessToken form value is used instead. A header with any other scheme results in an
// invalid-access-token error.
func (mw *TokenAuthMiddleware) extractAccessToken(ctx *gin.Context) (ret string, err error) {
	header := ctx.GetHeader("Authorization")
	if header == "" {
		if mw.postBodyEnabled {
			ret = ctx.PostForm(oauth2.ParameterAccessToken)
		}
		return ret, nil
	}
	// compare only the prefix instead of upper-casing (and copying) the whole header
	prefixLen := len(bearerTokenPrefix)
	if len(header) < prefixLen || !strings.EqualFold(header[:prefixLen], bearerTokenPrefix) {
		return "", oauth2.NewInvalidAccessTokenError("missing bearer token")
	}
	// RFC 6750 §2.1 allows one or more spaces between "Bearer" and the token; don't let extra spaces
	// become part of the token value
	return strings.TrimSpace(header[prefixLen:]), nil
}
// handleError clears the current authentication, records err on the gin context (wrapped as an
// invalid-access-token error unless it already is an OAuth2 error) and aborts the handler chain.
func (mw *TokenAuthMiddleware) handleError(c *gin.Context, err error) {
	reported := err
	if !errors.Is(reported, oauth2.ErrorTypeOAuth2) {
		reported = oauth2.NewInvalidAccessTokenError(reported)
	}
	security.MustClear(c)
	_ = c.Error(reported)
	c.Abort()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tokenauth
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"go.uber.org/fx"
)
//var logger = log.New("OAuth2.Token")
// Module is the bootstrap module for the OAuth2 resource server (token auth); it provides no fx options.
//goland:noinspection GoNameStartsWithPackageName
var Module = &bootstrap.Module{
Name: "oauth2 resource server",
Precedence: security.MinSecurityPrecedence + 20,
Options: []fx.Option{},
}
// init registers Module with bootstrap when this package is imported.
func init() {
bootstrap.Register(Module)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/template"
"go.uber.org/fx"
)
var logger = log.New("Security")
// Module is the core "security" bootstrap module. It provides the security
// Initializer/Registrar (provideSecurityInitialization) and invokes initialize through fx.
var Module = &bootstrap.Module{
Name: "security",
Precedence: MaxSecurityPrecedence,
Options: []fx.Option{
fx.Provide(provideSecurityInitialization),
fx.Invoke(initialize),
},
}
// Use includes the security module in an application (typically called from main()).
// It registers Module with bootstrap and registers Get as the global template model
// valuer for template.ModelKeySecurity.
func Use() {
bootstrap.Register(Module)
template.RegisterGlobalModelValuer(template.ModelKeySecurity, template.ContextModelValuer(Get))
}
/**************************
Provider
***************************/
// dependencies are the fx-injected inputs of provideSecurityInitialization.
type dependencies struct {
	fx.In
	// GlobalAuthenticator is optional; when it is not provided the zero value is passed to newSecurity.
	GlobalAuthenticator Authenticator `optional:"true"`
	// may be generic security properties
}

// global publishes a single security initializer under two fx result types.
type global struct {
	fx.Out
	Initializer Initializer
	Registrar   Registrar
}

// provideSecurityInitialization creates the security initializer and exposes the same
// instance as both Initializer and Registrar, so either type can be autowired.
func provideSecurityInitialization(di dependencies) global {
	var out global
	sec := newSecurity(di.GlobalAuthenticator)
	out.Initializer = sec
	out.Registrar = sec
	return out
}
/**************************
Initialize
***************************/
// initDI are the fx-injected inputs of initialize.
type initDI struct {
	fx.In
	AppContext  *bootstrap.ApplicationContext
	Registerer  *web.Registrar `optional:"true"`
	Initializer Initializer
}

// initialize runs the security Initializer with the application context, the fx
// lifecycle and the (optional) web registrar. Any initialization error panics,
// which aborts application start-up.
func initialize(lc fx.Lifecycle, di initDI) {
	err := di.Initializer.Initialize(di.AppContext, lc, di.Registerer)
	if err == nil {
		return
	}
	panic(err)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"sort"
"time"
)
/******************************
security.Authenticator
******************************/
// Authenticator implements security.Authenticator for *UsernamePasswordPair candidates.
type Authenticator struct {
	accountStore      security.AccountStore
	passwdEncoder     PasswordEncoder
	otpManager        OTPManager
	mfaEventListeners []MFAEventListenerFunc
	checkers          []AuthenticationDecisionMaker
	postProcessors    []PostAuthenticationProcessor
}

// AuthenticatorOptionsFunc customizes AuthenticatorOptions during construction.
type AuthenticatorOptionsFunc func(*AuthenticatorOptions)

// AuthenticatorOptions collects the collaborators of the authenticators in this package.
type AuthenticatorOptions struct {
	AccountStore      security.AccountStore
	PasswordEncoder   PasswordEncoder
	OTPManager        OTPManager
	MFAEventListeners []MFAEventListenerFunc
	Checkers          []AuthenticationDecisionMaker
	PostProcessors    []PostAuthenticationProcessor
}

// NewAuthenticator builds an Authenticator from the given option functions (nil ones are
// skipped). The password encoder defaults to the no-op (clear text) encoder.
// Checkers are stably sorted with order.OrderedFirstCompare and post-processors with
// order.OrderedFirstCompareReverse. Sorting is done in place, so slices supplied through
// options are reordered for their other holders as well.
func NewAuthenticator(optionFuncs ...AuthenticatorOptionsFunc) *Authenticator {
	opts := AuthenticatorOptions{
		PasswordEncoder:   NewNoopPasswordEncoder(),
		MFAEventListeners: []MFAEventListenerFunc{},
	}
	for _, apply := range optionFuncs {
		if apply == nil {
			continue
		}
		apply(&opts)
	}

	checkers := opts.Checkers
	sort.SliceStable(checkers, func(i, j int) bool {
		return order.OrderedFirstCompare(checkers[i], checkers[j])
	})

	processors := opts.PostProcessors
	sort.SliceStable(processors, func(i, j int) bool {
		return order.OrderedFirstCompareReverse(processors[i], processors[j])
	})

	return &Authenticator{
		accountStore:      opts.AccountStore,
		passwdEncoder:     opts.PasswordEncoder,
		otpManager:        opts.OTPManager,
		mfaEventListeners: opts.MFAEventListeners,
		checkers:          checkers,
		postProcessors:    processors,
	}
}
// Authenticate implements security.Authenticator for *UsernamePasswordPair candidates.
// Any other candidate type yields (nil, nil), before post-processing is scheduled.
//
// Steps:
//   - load the account by username
//   - run the checkers as pre-credential checks (nil proposed authentication)
//   - verify username and password
//   - build the authentication with CreateSuccessAuthentication
//   - run the checkers again as post-credential checks
//
// The deferred post-processors receive the loaded account and the named results.
// They may replace both the returned authentication and the error.
// NOTE(review): if LoadAccountByUsername can return (nil, nil), user.Credentials() below
// would panic — confirm the AccountStore contract.
func (a *Authenticator) Authenticate(ctx context.Context, candidate security.Candidate) (auth security.Authentication, err error) {
upp, ok := candidate.(*UsernamePasswordPair)
if !ok {
return nil, nil
}
// schedule post processing
// NOTE(review): presumably made mutable so checkers/post-processors can attach values — confirm.
ctx = utils.MakeMutableContext(ctx) //nolint:contextcheck
var user security.Account
defer func() {
auth, err = applyPostAuthenticationProcessors(a.postProcessors, ctx, user, candidate, auth, err)
}()
// Load the account from the account store. `:=` here reuses the `user` declared above
// (same scope), so the deferred post-processing sees the loaded account.
user, e := a.accountStore.LoadAccountByUsername(ctx, upp.Username)
if e != nil {
err = security.NewUsernameNotFoundError(MessageUserNotFound, e)
return
}
// pre checks
if e := makeDecision(a.checkers, ctx, upp, user, nil); e != nil {
err = a.translate(e)
return
}
// Check password: credentials must be a string, the stored username must equal the
// submitted one exactly (case-sensitive), and the encoder must match the password.
if password, ok := user.Credentials().(string);
!ok || upp.Username != user.Username() || !a.passwdEncoder.Matches(upp.Password, password) {
err = security.NewBadCredentialsError(MessageBadCredential)
return
}
// create authentication
newAuth, e := a.CreateSuccessAuthentication(upp, user)
if e != nil {
err = a.translate(e)
return
}
// post checks
if e := makeDecision(a.checkers, ctx, upp, user, newAuth); e != nil {
err = a.translate(e)
return
}
auth = newAuth
return
}
// CreateSuccessAuthentication builds the authentication returned once the password of
// candidate has been verified for account. It is exported so it can be overridden.
//
// MFA is required when candidate.EnforceMFA is MFAModeMust, or when it is not MFAModeSkip
// and account.UseMFA() is true. In that case:
//   - a new OTP is created through the OTP manager
//   - the authentication carries only the SpecialPermissionMFAPending and
//     SpecialPermissionOtpId permissions
//   - MFAEventOtpCreate is broadcast to the MFA event listeners
//
// Otherwise all account permissions are granted and security.DetailsKeyAuthTime is set.
//
// An internal authentication error is returned when MFA is required but no OTP manager
// is configured, or the OTP cannot be created. The OTP creation error is kept as the cause.
//
// Note: a non-nil candidate.DetailsMap is reused (and written to) as the details map.
func (a *Authenticator) CreateSuccessAuthentication(candidate *UsernamePasswordPair, account security.Account) (security.Authentication, error) {
	details := candidate.DetailsMap
	if details == nil {
		details = map[string]interface{}{}
	}
	permissions := map[string]interface{}{}

	// MFA support
	mfaRequired := candidate.EnforceMFA == MFAModeMust ||
		(candidate.EnforceMFA != MFAModeSkip && account.UseMFA())
	if mfaRequired {
		if a.otpManager == nil {
			return nil, security.NewInternalAuthenticationError(MessageOtpNotAvailable)
		}
		otp, err := a.otpManager.New()
		if err != nil {
			// keep the OTP manager failure as the cause for diagnostics
			return nil, security.NewInternalAuthenticationError(MessageOtpNotAvailable, err)
		}
		permissions[SpecialPermissionMFAPending] = true
		permissions[SpecialPermissionOtpId] = otp.ID()
		broadcastMFAEvent(MFAEventOtpCreate, otp, account, a.mfaEventListeners...)
	} else {
		// MFA skipped: the user is fully authenticated now
		details[security.DetailsKeyAuthTime] = time.Now().UTC()
		for _, p := range account.Permissions() {
			permissions[p] = true
		}
	}

	auth := usernamePasswordAuthentication{
		Acct:       account.CacheableCopy(),
		Perms:      permissions,
		DetailsMap: details,
	}
	return &auth, nil
}
// translate returns security errors unchanged and wraps any other error into an
// account status error with MessageAccountStatus.
func (a *Authenticator) translate(err error) error {
	if !errors.Is(err, security.ErrorTypeSecurity) {
		return security.NewAccountStatusError(MessageAccountStatus, err)
	}
	return err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
/********************************
MfaVerifyAuthenticator
*********************************/
// MfaVerifyAuthenticator implements security.Authenticator for *MFAOtpVerification
// candidates, verifying the OTP of the candidate's current authentication.
type MfaVerifyAuthenticator struct {
	accountStore      security.AccountStore
	otpStore          OTPManager
	mfaEventListeners []MFAEventListenerFunc
	checkers          []AuthenticationDecisionMaker
	postProcessors    []PostAuthenticationProcessor
}

// NewMFAVerifyAuthenticator creates a MfaVerifyAuthenticator from the given option
// functions. Like NewAuthenticator, it ignores nil option functions instead of panicking.
// Checkers and post-processors are used in the order supplied; no sorting is done here.
func NewMFAVerifyAuthenticator(optionFuncs ...AuthenticatorOptionsFunc) *MfaVerifyAuthenticator {
	options := AuthenticatorOptions{
		MFAEventListeners: []MFAEventListenerFunc{},
	}
	for _, optFunc := range optionFuncs {
		if optFunc != nil {
			optFunc(&options)
		}
	}
	return &MfaVerifyAuthenticator{
		accountStore:      options.AccountStore,
		otpStore:          options.OTPManager,
		mfaEventListeners: options.MFAEventListeners,
		checkers:          options.Checkers,
		postProcessors:    options.PostProcessors,
	}
}
// Authenticate implements security.Authenticator for *MFAOtpVerification candidates.
// Any other candidate type yields (nil, nil), before post-processing is scheduled.
//
// Steps:
//   - reload the account of the current authentication (checkCurrentAuth)
//   - run the pre checks
//   - verify the OTP identified by the current authentication, broadcasting the
//     verification success/failure event
//   - build the final authentication
//   - run the post checks
//
// The deferred post-processors receive the account and may replace the named results.
// NOTE(review): pre-check failures go through translate(e, true), so any checker
// rejection (e.g. account status) surfaces as BadCredentials with MessageInvalidPasscode.
// Post-check errors are returned untranslated. Confirm both are intended.
func (a *MfaVerifyAuthenticator) Authenticate(ctx context.Context, candidate security.Candidate) (auth security.Authentication, err error) {
verify, ok := candidate.(*MFAOtpVerification)
if !ok {
return nil, nil
}
// schedule post processing
ctx = utils.MakeMutableContext(ctx) //nolint:contextcheck
var user security.Account
defer func() {
auth, err = applyPostAuthenticationProcessors(a.postProcessors, ctx, user, candidate, auth, err)
}()
// check if OTP verification should be performed
// (assigns the outer `user`, so the deferred post-processing sees the account)
user, err = checkCurrentAuth(ctx, verify.CurrentAuth, a.accountStore)
if err != nil {
return
}
// pre checks
if e := makeDecision(a.checkers, ctx, verify, user, nil); e != nil {
err = a.translate(e, true)
return
}
// Check OTP; `more` from Verify selects the translated error
id := verify.CurrentAuth.OTPIdentifier()
switch otp, more, e := a.otpStore.Verify(id, verify.OTP); {
case e != nil:
broadcastMFAEvent(MFAEventVerificationFailure, otp, user, a.mfaEventListeners...)
err = a.translate(e, more)
return
default:
broadcastMFAEvent(MFAEventVerificationSuccess, otp, user, a.mfaEventListeners...)
}
newAuth, e := a.CreateSuccessAuthentication(verify, user)
if e != nil {
err = a.translate(e, true)
return
}
// post checks
if e := makeDecision(a.checkers, ctx, verify, user, newAuth); e != nil {
err = e
return
}
auth = newAuth
return
}
// CreateSuccessAuthentication builds the fully authenticated result after a successful
// OTP verification. It is exported so it can be overridden.
//
// All permissions of account are granted. The MFA-pending special permissions of the
// current authentication are not carried over.
//
// Details of candidate.CurrentAuth are handled as follows:
//   - a map[string]interface{} is reused and written to
//   - an absent or nil map is replaced with a new empty map
//   - any other non-nil value is kept under the "Literal" key of a new map
//
// security.DetailsKeyAuthTime is set to the current UTC time.
// The account is stored as account.CacheableCopy(), consistent with
// Authenticator.CreateSuccessAuthentication, because this authentication is gob-registered
// for serialization.
func (a *MfaVerifyAuthenticator) CreateSuccessAuthentication(candidate *MFAOtpVerification, account security.Account) (security.Authentication, error) {
	permissions := map[string]interface{}{}
	for _, p := range account.Permissions() {
		permissions[p] = true
	}

	rawDetails := candidate.CurrentAuth.Details()
	details, ok := rawDetails.(map[string]interface{})
	if !ok && rawDetails != nil {
		// non-map details are preserved, wrapped in a fresh map
		details = map[string]interface{}{"Literal": rawDetails}
	} else if details == nil {
		details = map[string]interface{}{}
	}
	details[security.DetailsKeyAuthTime] = time.Now().UTC()

	auth := usernamePasswordAuthentication{
		Acct:       account.CacheableCopy(),
		Perms:      permissions,
		DetailsMap: details,
	}
	return &auth, nil
}
// translate maps an OTP verification failure to a security error.
//   - more == true: bad credentials with MessageInvalidPasscode
//   - otherwise, an expired credential: credentials-expired with MessagePasscodeExpired
//   - otherwise: max-attempts-reached, with MessageMaxAttemptsReached when err is the
//     max-attempts sentinel and MessageInvalidPasscode for anything else
func (a *MfaVerifyAuthenticator) translate(err error, more bool) error {
	if more {
		return security.NewBadCredentialsError(MessageInvalidPasscode, err)
	}
	if errors.Is(err, errorCredentialsExpired) {
		return security.NewCredentialsExpiredError(MessagePasscodeExpired, err)
	}
	msg := MessageInvalidPasscode
	if errors.Is(err, errorMaxAttemptsReached) {
		msg = MessageMaxAttemptsReached
	}
	return security.NewMaxAttemptsReachedError(msg, err)
}
/********************************
MfaRefreshAuthenticator
*********************************/
// MfaRefreshAuthenticator implements security.Authenticator for *MFAOtpRefresh
// candidates, refreshing the OTP of the candidate's current authentication.
type MfaRefreshAuthenticator struct {
	accountStore      security.AccountStore
	otpStore          OTPManager
	mfaEventListeners []MFAEventListenerFunc
	checkers          []AuthenticationDecisionMaker
	postProcessors    []PostAuthenticationProcessor
}

// NewMFARefreshAuthenticator creates a MfaRefreshAuthenticator from the given option
// functions. Like NewAuthenticator, it ignores nil option functions instead of panicking.
// Checkers and post-processors are used in the order supplied; no sorting is done here.
func NewMFARefreshAuthenticator(optionFuncs ...AuthenticatorOptionsFunc) *MfaRefreshAuthenticator {
	options := AuthenticatorOptions{
		MFAEventListeners: []MFAEventListenerFunc{},
	}
	for _, optFunc := range optionFuncs {
		if optFunc != nil {
			optFunc(&options)
		}
	}
	return &MfaRefreshAuthenticator{
		accountStore:      options.AccountStore,
		otpStore:          options.OTPManager,
		mfaEventListeners: options.MFAEventListeners,
		checkers:          options.Checkers,
		postProcessors:    options.PostProcessors,
	}
}
// Authenticate implements security.Authenticator for *MFAOtpRefresh candidates.
// Any other candidate type yields (nil, nil), before post-processing is scheduled.
//
// Steps:
//   - reload the account of the current authentication (checkCurrentAuth)
//   - run the pre checks
//   - refresh the OTP identified by the current authentication, broadcasting
//     MFAEventOtpRefresh on success
//   - return the result of CreateSuccessAuthentication after the post checks
//
// The deferred post-processors receive the account and may replace the named results.
// NOTE(review): pre-check failures are translated with more=true (BadCredentials,
// MessageCannotRefresh); post-check errors are returned untranslated — confirm intent.
func (a *MfaRefreshAuthenticator) Authenticate(ctx context.Context, candidate security.Candidate) (auth security.Authentication, err error) {
refresh, ok := candidate.(*MFAOtpRefresh)
if !ok {
return nil, nil
}
// schedule post processing
ctx = utils.MakeMutableContext(ctx) //nolint:contextcheck
var user security.Account
defer func() {
auth, err = applyPostAuthenticationProcessors(a.postProcessors, ctx, user, candidate, auth, err)
}()
// check if OTP refresh should be performed
// (assigns the outer `user`, so the deferred post-processing sees the account)
user, err = checkCurrentAuth(ctx, refresh.CurrentAuth, a.accountStore)
if err != nil {
return
}
// pre checks
if e := makeDecision(a.checkers, ctx, refresh, user, nil); e != nil {
err = a.translate(e, true)
return
}
// Refresh OTP; `more` from Refresh selects the translated error
id := refresh.CurrentAuth.OTPIdentifier()
switch otp, more, e := a.otpStore.Refresh(id); {
case e != nil:
err = a.translate(e, more)
return
default:
broadcastMFAEvent(MFAEventOtpRefresh, otp, user, a.mfaEventListeners...)
}
newAuth, e := a.CreateSuccessAuthentication(refresh, user)
if e != nil {
err = a.translate(e, true)
return
}
// post checks
if e := makeDecision(a.checkers, ctx, refresh, user, newAuth); e != nil {
err = e
return
}
auth = newAuth
return
}
// CreateSuccessAuthentication exported for override possibility.
// A successful OTP refresh does not change the authentication: candidate.CurrentAuth
// (presumably still MFA-pending) is returned as-is and the account argument is ignored.
func (a *MfaRefreshAuthenticator) CreateSuccessAuthentication(candidate *MFAOtpRefresh, _ security.Account) (security.Authentication, error) {
return candidate.CurrentAuth, nil
}
// translate maps an OTP refresh failure to a security error.
//   - more == true: bad credentials with MessageCannotRefresh
//   - otherwise, an expired credential: credentials-expired with MessagePasscodeExpired
//   - otherwise: max-attempts-reached, with MessageMaxRefreshAttemptsReached when err is
//     the max-attempts sentinel and MessageCannotRefresh for anything else
func (a *MfaRefreshAuthenticator) translate(err error, more bool) error {
	switch {
	case more:
		return security.NewBadCredentialsError(MessageCannotRefresh, err)
	case errors.Is(err, errorCredentialsExpired):
		return security.NewCredentialsExpiredError(MessagePasscodeExpired, err)
	case errors.Is(err, errorMaxAttemptsReached):
		return security.NewMaxAttemptsReachedError(MessageMaxRefreshAttemptsReached, err)
	}
	return security.NewMaxAttemptsReachedError(MessageCannotRefresh, err)
}
/************************
Helpers
************************/
// checkCurrentAuth reloads the account behind currentAuth from accountStore.
// It returns a username-not-found error (MessageInvalidAccountStatus) in two cases:
//   - there is no current authentication
//   - the account cannot be loaded; the load error is kept as the cause
func checkCurrentAuth(ctx context.Context, currentAuth UsernamePasswordAuthentication, accountStore security.AccountStore) (security.Account, error) {
	if currentAuth == nil {
		return nil, security.NewUsernameNotFoundError(MessageInvalidAccountStatus)
	}
	account, loadErr := accountStore.LoadAccountByUsername(ctx, currentAuth.Username())
	if loadErr == nil {
		return account, nil
	}
	return nil, security.NewUsernameNotFoundError(MessageInvalidAccountStatus, loadErr)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/pquerna/otp"
"sort"
"time"
)
// builderDefaults holds fallback dependencies used by AuthenticatorBuilder when the
// PasswordAuthFeature does not provide them.
type builderDefaults struct {
	accountStore    security.AccountStore
	passwordEncoder PasswordEncoder
	redisClient     redis.Client
}

// AuthenticatorBuilder implements security.AuthenticatorBuilder
type AuthenticatorBuilder struct {
	feature  *PasswordAuthFeature
	defaults *builderDefaults
}

// NewAuthenticatorBuilder creates an AuthenticatorBuilder for feature f.
// Only the last element of defaults is used. When none is given, or the last one is nil,
// empty defaults are used, so Build never dereferences a nil *builderDefaults.
func NewAuthenticatorBuilder(f *PasswordAuthFeature, defaults ...*builderDefaults) *AuthenticatorBuilder {
	builder := &AuthenticatorBuilder{
		feature:  f,
		defaults: &builderDefaults{},
	}
	if n := len(defaults); n != 0 && defaults[n-1] != nil {
		builder.defaults = defaults[n-1]
	}
	return builder
}
// Build creates the username/password authenticator. When MFA is enabled on the feature,
// it returns a composite of the password, MFA-verify and MFA-refresh authenticators,
// all built from the same default and MFA options.
func (b *AuthenticatorBuilder) Build(_ context.Context) (security.Authenticator, error) {
	defaultOpts, err := b.defaultOptions(b.feature)
	if err != nil {
		return nil, err
	}
	mfaOpts, err := b.mfaOptions(b.feature)
	if err != nil {
		return nil, err
	}
	opts := []AuthenticatorOptionsFunc{defaultOpts, mfaOpts}

	// username password authenticator
	passwdAuth := NewAuthenticator(opts...)
	if !b.feature.mfaEnabled {
		return passwdAuth, nil
	}
	// MFA: arguments are evaluated left to right, preserving construction order
	return security.NewAuthenticator(
		passwdAuth,
		NewMFAVerifyAuthenticator(opts...),
		NewMFARefreshAuthenticator(opts...),
	), nil
}
// defaultOptions fills the feature's account store and password encoder from the builder
// defaults when unset (an account store is mandatory). It returns the option that
// installs them together with the default checkers and post-processors.
func (b *AuthenticatorBuilder) defaultOptions(f *PasswordAuthFeature) (AuthenticatorOptionsFunc, error) {
	if f.accountStore == nil && b.defaults.accountStore == nil {
		return nil, fmt.Errorf("unable to create password authenticator: account accountStore is not set")
	}
	if f.accountStore == nil {
		f.accountStore = b.defaults.accountStore
	}
	if f.passwordEncoder == nil {
		f.passwordEncoder = b.defaults.passwordEncoder
	}

	checkers := b.prepareDecisionMakers(f)
	postProcessors := b.preparePostProcessors(f)
	return func(opts *AuthenticatorOptions) {
		opts.AccountStore = f.accountStore
		// keep the option's existing (default) encoder when none is configured
		if f.passwordEncoder != nil {
			opts.PasswordEncoder = f.passwordEncoder
		}
		opts.Checkers = checkers
		opts.PostProcessors = postProcessors
	}, nil
}
// mfaOptions returns the option that configures MFA: OTP manager, MFA event listeners,
// checkers and post-processors. When MFA is disabled it returns a no-op option.
//
// Unset or too-small OTP settings on f are replaced with defaults:
//   - TTL: 10 minutes
//   - verify limit: 3
//   - refresh limit: 3
//   - OTP length: at least 3 digits
//   - secret size: at least 5
//
// A redis-backed OTP store is used when the builder defaults include a redis client.
// f.mfaEventListeners is stably sorted in place (order.OrderedFirstCompare) once, here,
// rather than on every application of the returned option.
func (b *AuthenticatorBuilder) mfaOptions(f *PasswordAuthFeature) (AuthenticatorOptionsFunc, error) {
	if !f.mfaEnabled {
		return func(*AuthenticatorOptions) { /* noop */ }, nil
	}
	if f.otpTTL <= 0 {
		f.otpTTL = 10 * time.Minute
	}
	if f.otpVerifyLimit <= 0 {
		f.otpVerifyLimit = 3
	}
	if f.otpRefreshLimit <= 0 {
		f.otpRefreshLimit = 3
	}
	if f.otpLength <= 3 {
		f.otpLength = 3
	}
	if f.otpSecretSize <= 5 {
		f.otpSecretSize = 5
	}

	otpManager := newTotpManager(func(s *totpManager) {
		s.ttl = f.otpTTL
		s.maxVerifyLimit = f.otpVerifyLimit
		s.maxRefreshLimit = f.otpRefreshLimit
		if b.defaults.redisClient != nil {
			s.store = newRedisOtpStore(b.defaults.redisClient)
		}
		s.factory = newTotpFactory(func(factory *totpFactory) {
			factory.digits = otp.Digits(f.otpLength)
			factory.secretSize = int(f.otpSecretSize)
		})
	})

	// loop-invariant: sort once instead of each time the option is applied
	sort.SliceStable(f.mfaEventListeners, func(i, j int) bool {
		return order.OrderedFirstCompare(f.mfaEventListeners[i], f.mfaEventListeners[j])
	})

	decisionMakers := b.prepareDecisionMakers(f)
	processors := b.preparePostProcessors(f)
	return func(opts *AuthenticatorOptions) {
		opts.OTPManager = otpManager
		opts.MFAEventListeners = f.mfaEventListeners
		opts.Checkers = decisionMakers
		opts.PostProcessors = processors
	}, nil
}
// prepareDecisionMakers returns the default checkers:
//   - an account status check before credentials are validated
//   - a password policy check at the final (authenticated) stage
func (b *AuthenticatorBuilder) prepareDecisionMakers(f *PasswordAuthFeature) []AuthenticationDecisionMaker {
	// maybe customizable via Feature
	return []AuthenticationDecisionMaker{
		PreCredentialsCheck(NewAccountStatusChecker(f.accountStore)),
		FinalCheck(NewPasswordPolicyChecker(f.accountStore)),
	}
}
// preparePostProcessors returns the default post-authentication processors, in the
// order: persist account, additional details, account status, account locking.
func (b *AuthenticatorBuilder) preparePostProcessors(f *PasswordAuthFeature) []PostAuthenticationProcessor {
	// maybe customizable via Feature
	processors := make([]PostAuthenticationProcessor, 0, 4)
	processors = append(processors, NewPersistAccountPostProcessor(f.accountStore))
	processors = append(processors, NewAdditionalDetailsPostProcessor())
	processors = append(processors, NewAccountStatusPostProcessor(f.accountStore))
	processors = append(processors, NewAccountLockingPostProcessor(f.accountStore))
	return processors
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security"
)
var (
// PasswordAuthenticatorFeatureId identifies the password authentication feature;
// it is ordered with security.FeatureOrderAuthenticator.
PasswordAuthenticatorFeatureId = security.FeatureId("passwdAuth", security.FeatureOrderAuthenticator)
)
// PasswordAuthConfigurer applies PasswordAuthFeature to a WebSecurity. Its fields serve
// as fallbacks for settings the feature does not provide.
type PasswordAuthConfigurer struct {
	accountStore    security.AccountStore
	passwordEncoder PasswordEncoder
	redisClient     redis.Client
}

// newPasswordAuthConfigurer creates a PasswordAuthConfigurer with the given fallbacks.
func newPasswordAuthConfigurer(store security.AccountStore, encoder PasswordEncoder, redisClient redis.Client) *PasswordAuthConfigurer {
	pac := new(PasswordAuthConfigurer)
	pac.accountStore = store
	pac.passwordEncoder = encoder
	pac.redisClient = redisClient
	return pac
}
// Apply builds the password authenticator(s) for feature and adds them to the composite
// authenticator of ws. A composite result is merged (flattened) into it.
// It returns an error in these cases:
//   - feature is not a non-nil *PasswordAuthFeature
//   - validation fails
//   - the authenticator cannot be built
func (pac *PasswordAuthConfigurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	f, ok := feature.(*PasswordAuthFeature)
	if !ok || f == nil {
		return fmt.Errorf("unable to apply password authenticator: unsupported feature %T", feature)
	}
	// Verify
	if err := pac.validate(f, ws); err != nil {
		return err
	}

	// Build authenticator
	defaults := &builderDefaults{
		accountStore:    pac.accountStore,
		passwordEncoder: pac.passwordEncoder,
		redisClient:     pac.redisClient,
	}
	authenticator, err := NewAuthenticatorBuilder(f, defaults).Build(context.Background())
	if err != nil {
		return err
	}

	// Add authenticator to WS, flatten if multiple. validate guarantees the type.
	target := ws.Authenticator().(*security.CompositeAuthenticator)
	if composite, isComposite := authenticator.(*security.CompositeAuthenticator); isComposite {
		target.Merge(composite)
	} else {
		target.Add(authenticator)
	}
	return nil
}
// validate checks that ws uses a composite authenticator and that an account store is
// available from either the feature or the configurer.
func (pac *PasswordAuthConfigurer) validate(f *PasswordAuthFeature, ws security.WebSecurity) error {
	current := ws.Authenticator()
	if _, isComposite := current.(*security.CompositeAuthenticator); !isComposite {
		return fmt.Errorf("unable to add password authenticator to %T", current)
	}
	if f.accountStore != nil || pac.accountStore != nil {
		return nil
	}
	return fmt.Errorf("unable to create password authenticator: account accountStore is not set")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import (
"encoding/gob"
"github.com/cisco-open/go-lanai/pkg/security"
"time"
)
// Permission keys placed on an authentication while MFA is pending:
// SpecialPermissionMFAPending marks the pending state, SpecialPermissionOtpId holds the OTP id.
const (
SpecialPermissionMFAPending = "MFAPending"
SpecialPermissionOtpId = "OtpId"
)
// Messages of the security errors raised by this package.
// MessageUserNotFound and MessageBadCredential are intentionally identical so a response
// does not reveal whether the username exists.
// NOTE(review): MessageOtpNotAvailable contains the typo "temprorily"; left unchanged
// because it is a user-facing runtime string.
const (
MessageUserNotFound = "Mismatched Username and Password"
MessageBadCredential = "Mismatched Username and Password"
MessageOtpNotAvailable = "MFA required but temprorily unavailable"
MessageAccountStatus = "Inactive Account"
MessageInvalidPasscode = "Bad Verification Code"
MessagePasscodeExpired = "Verification Code Expired"
MessageCannotRefresh = "Unable to Refresh"
MessageMaxAttemptsReached = "No More Verification Attempts Allowed"
MessageMaxRefreshAttemptsReached = "No More Resend Attempts Allowed"
MessageInvalidAccountStatus = "Issue with current account status"
MessageAccountDisabled = "Account Disabled"
MessageAccountLocked = "Account Locked"
MessagePasswordLoginNotAllowed = "Password Login not Allowed"
MessageLockedDueToBadCredential = "Mismatched Username and Password. Account locked due to too many failed attempts"
MessagePasswordExpired = "User credentials have expired"
)
// For error translation: sentinel errors matched with errors.Is by the translate methods.
var (
errorBadCredentials = security.NewBadCredentialsError("bad creds")
errorCredentialsExpired = security.NewCredentialsExpiredError("cred exp")
errorMaxAttemptsReached = security.NewMaxAttemptsReachedError("max attempts")
//errorAccountStatus = security.NewAccountStatusError("acct status")
)
/******************************
Serialization
******************************/
// GobRegister registers this package's serializable authentication and OTP types with
// encoding/gob, in a fixed order, so they can be encoded behind interface values.
func GobRegister() {
	registrations := []interface{}{
		(*usernamePasswordAuthentication)(nil),
		(*timeBasedOtp)(nil),
		TOTP{},
		time.Time{},
		time.Duration(0),
	}
	for _, value := range registrations {
		gob.Register(value)
	}
}
/************************
security.Candidate
************************/
// MFAMode controls whether username/password authentication requires MFA.
type MFAMode int
const (
// MFAModeSkip never requires MFA.
MFAModeSkip MFAMode = iota
// MFAModeOptional requires MFA only when the account's UseMFA() returns true.
MFAModeOptional
// MFAModeMust always requires MFA.
MFAModeMust
)
// UsernamePasswordPair is the security.Candidate handled by Authenticator.
// EnforceMFA selects whether MFA is skipped, optional (per account) or mandatory.
type UsernamePasswordPair struct {
	Username   string
	Password   string
	DetailsMap map[string]interface{}
	EnforceMFA MFAMode
}

// Principal implements security.Candidate and returns the username.
func (pair *UsernamePasswordPair) Principal() interface{} {
	return pair.Username
}

// Credentials implements security.Candidate and returns the raw password.
func (pair *UsernamePasswordPair) Credentials() interface{} {
	return pair.Password
}

// Details implements security.Candidate and returns DetailsMap.
func (pair *UsernamePasswordPair) Details() interface{} {
	return pair.DetailsMap
}
// MFAOtpVerification is the security.Candidate handled by MfaVerifyAuthenticator:
// the submitted OTP for an MFA-pending current authentication.
type MFAOtpVerification struct {
	CurrentAuth UsernamePasswordAuthentication
	OTP         string
	DetailsMap  map[string]interface{}
}

// Principal implements security.Candidate and delegates to the current authentication.
func (mv *MFAOtpVerification) Principal() interface{} {
	return mv.CurrentAuth.Principal()
}

// Credentials implements security.Candidate and returns the submitted OTP.
func (mv *MFAOtpVerification) Credentials() interface{} {
	return mv.OTP
}

// Details implements security.Candidate and returns DetailsMap.
func (mv *MFAOtpVerification) Details() interface{} {
	return mv.DetailsMap
}
// MFAOtpRefresh is the security.Candidate handled by MfaRefreshAuthenticator:
// a request to refresh the OTP of an MFA-pending current authentication.
type MFAOtpRefresh struct {
	CurrentAuth UsernamePasswordAuthentication
	DetailsMap  map[string]interface{}
}

// Principal implements security.Candidate and delegates to the current authentication.
func (mr *MFAOtpRefresh) Principal() interface{} {
	return mr.CurrentAuth.Principal()
}

// Credentials implements security.Candidate and returns the OTP identifier of the
// current authentication.
func (mr *MFAOtpRefresh) Credentials() interface{} {
	return mr.CurrentAuth.OTPIdentifier()
}

// Details implements security.Candidate and returns DetailsMap.
func (mr *MFAOtpRefresh) Details() interface{} {
	return mr.DetailsMap
}
/******************************
security.Authentication
******************************/
// UsernamePasswordAuthentication extends security.Authentication with the username and
// the MFA state: whether an OTP is pending and, if so, its identifier.
type UsernamePasswordAuthentication interface {
security.Authentication
Username() string
IsMFAPending() bool
OTPIdentifier() string
}
// TODO: do we want the details here to also implement the ctx_details interfaces?

// usernamePasswordAuthentication is the UsernamePasswordAuthentication produced by the
// authenticators of this package.
// Note: fields should not be used directly; they are exported only because gob
// serializes exported fields exclusively.
type usernamePasswordAuthentication struct {
	Acct       security.Account
	Perms      map[string]interface{}
	DetailsMap map[string]interface{}
}

// Principal returns the account.
func (auth *usernamePasswordAuthentication) Principal() interface{} {
	return auth.Acct
}

// Permissions returns the granted permissions (or the MFA special permissions).
func (auth *usernamePasswordAuthentication) Permissions() security.Permissions {
	return auth.Perms
}

// State is StatePrincipalKnown while MFA is pending and StateAuthenticated otherwise.
func (auth *usernamePasswordAuthentication) State() security.AuthenticationState {
	if auth.IsMFAPending() {
		return security.StatePrincipalKnown
	}
	return security.StateAuthenticated
}

// Details returns the details map.
func (auth *usernamePasswordAuthentication) Details() interface{} {
	return auth.DetailsMap
}

// Username returns the username of the account.
func (auth *usernamePasswordAuthentication) Username() string {
	return auth.Acct.Username()
}

// IsMFAPending reports whether a string OTP id is present in the permissions.
func (auth *usernamePasswordAuthentication) IsMFAPending() bool {
	_, isString := auth.Permissions()[SpecialPermissionOtpId].(string)
	return isString
}

// OTPIdentifier returns the pending OTP id, or "" when none is present.
func (auth *usernamePasswordAuthentication) OTPIdentifier() string {
	id, _ := auth.Permissions()[SpecialPermissionOtpId].(string)
	return id
}
// IsSamePrincipal reports whether currentAuth (at least StatePrincipalKnown) belongs to
// username. The principal may be a security.Account (compared by Username()) or a
// plain string. Any other principal never matches.
func IsSamePrincipal(username string, currentAuth security.Authentication) bool {
	if currentAuth == nil || currentAuth.State() < security.StatePrincipalKnown {
		return false
	}
	switch principal := currentAuth.Principal().(type) {
	case security.Account:
		return username == principal.Username()
	case string:
		return username == principal
	default:
		return false
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import "golang.org/x/crypto/bcrypt"
// PasswordEncoder encodes raw passwords and checks a raw password against an encoded one.
type PasswordEncoder interface {
	Encode(rawPassword string) string
	Matches(raw, encoded string) bool
}

// noopPasswordEncoder keeps passwords as clear text; its string value is only a label.
type noopPasswordEncoder string

// NewNoopPasswordEncoder returns a PasswordEncoder that performs no hashing:
// Encode is the identity and Matches is plain string equality.
func NewNoopPasswordEncoder() PasswordEncoder {
	const label = "clear text"
	return noopPasswordEncoder(label)
}

// Encode returns rawPassword unchanged.
func (noopPasswordEncoder) Encode(rawPassword string) string {
	encoded := rawPassword
	return encoded
}

// Matches reports whether raw and encoded are equal.
func (noopPasswordEncoder) Matches(raw, encoded string) bool {
	return encoded == raw
}
// bcryptPasswordEncoder implements PasswordEncoder with bcrypt hashing at a fixed cost.
type bcryptPasswordEncoder struct {
	cost int
}

// NewBcryptPasswordEncoder returns a bcrypt PasswordEncoder using cost 10.
func NewBcryptPasswordEncoder() PasswordEncoder {
	enc := &bcryptPasswordEncoder{}
	enc.cost = 10
	return enc
}

// Encode hashes raw with bcrypt. Because the interface has no error return, a hashing
// failure yields an empty string.
func (enc *bcryptPasswordEncoder) Encode(raw string) string {
	hashed, err := bcrypt.GenerateFromPassword([]byte(raw), enc.cost)
	if err == nil {
		return string(hashed)
	}
	return ""
}

// Matches reports whether raw corresponds to the bcrypt hash encoded.
func (enc *bcryptPasswordEncoder) Matches(raw, encoded string) bool {
	return bcrypt.CompareHashAndPassword([]byte(encoded), []byte(raw)) == nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"time"
)
/******************************
abstracts
******************************/
// AuthenticationDecisionMaker is invoked at various stages of the authentication decision-making process.
// If an AuthenticationDecisionMaker implements the order.Ordered interface, its order is respected using order.OrderedFirstCompare.
// This means the highest priority is executed first and non-ordered decision makers run last.
//
// Note: each AuthenticationDecisionMaker gets invoked multiple times during the authentication process.
// So implementations should check the stage before making decisions, or use ConditionalDecisionMaker.
type AuthenticationDecisionMaker interface {
// Decide makes the decision on whether the Authenticator should approve the auth request.
// The returned error indicates the reason for rejection; nil is returned when approved.
// - The security.Authentication is nil when credentials have not been validated (pre check)
// - The security.Authentication is non-nil when credentials have been validated (post check).
// The non-nil value is the proposed authentication to be returned by the Authenticator
//
// If any of the input parameters are mutable, the AuthenticationDecisionMaker is allowed to change them
Decide(context.Context, security.Candidate, security.Account, security.Authentication) error
}
/******************************
Common Implementation
******************************/
// DecisionMakerConditionFunc decides whether a ConditionalDecisionMaker consults its delegate.
type DecisionMakerConditionFunc func(context.Context, security.Candidate, security.Account, security.Authentication) bool

// ConditionalDecisionMaker implements AuthenticationDecisionMaker and skips its delegate
// when the condition is not met.
type ConditionalDecisionMaker struct {
	delegate  AuthenticationDecisionMaker
	condition DecisionMakerConditionFunc
}

// Decide approves (returns nil) when there is no delegate or the condition rejects the
// current stage; otherwise it returns the delegate's decision. A nil condition always applies.
func (dm *ConditionalDecisionMaker) Decide(ctx context.Context, c security.Candidate, acct security.Account, auth security.Authentication) error {
	if dm.delegate == nil {
		return nil
	}
	if dm.condition != nil && !dm.condition(ctx, c, acct, auth) {
		return nil
	}
	return dm.delegate.Decide(ctx, c, acct, auth)
}
// PreCredentialsCheck runs delegate only before credentials are validated
// (no proposed authentication yet).
func PreCredentialsCheck(delegate AuthenticationDecisionMaker) AuthenticationDecisionMaker {
	dm := &ConditionalDecisionMaker{delegate: delegate}
	dm.condition = isPreCredentialsCheck
	return dm
}

// PostCredentialsCheck runs delegate only after credentials are validated
// (a proposed authentication exists).
func PostCredentialsCheck(delegate AuthenticationDecisionMaker) AuthenticationDecisionMaker {
	dm := &ConditionalDecisionMaker{delegate: delegate}
	dm.condition = isPostCredentialsCheck
	return dm
}

// FinalCheck runs delegate only when the proposed authentication is at least
// security.StateAuthenticated.
func FinalCheck(delegate AuthenticationDecisionMaker) AuthenticationDecisionMaker {
	dm := &ConditionalDecisionMaker{delegate: delegate}
	dm.condition = isFinalStage
	return dm
}
/******************************
	helpers
 ******************************/

// makeDecision consults each checker in slice order and returns the first rejection error.
// It returns nil when every checker approves, including when checkers is empty.
func makeDecision(checkers []AuthenticationDecisionMaker, ctx context.Context, can security.Candidate, acct security.Account, auth security.Authentication) error {
	for i := range checkers {
		err := checkers[i].Decide(ctx, can, acct, auth)
		if err != nil {
			return err
		}
	}
	return nil
}
// isPreCredentialsCheck is a DecisionMakerConditionFunc that matches the stage before credentials are validated,
// i.e. when no authentication has been proposed yet.
func isPreCredentialsCheck(_ context.Context, _ security.Candidate, _ security.Account, proposed security.Authentication) bool {
	return proposed == nil
}

// isPostCredentialsCheck is a DecisionMakerConditionFunc that matches the stage after credentials are validated,
// i.e. when an authentication has been proposed.
func isPostCredentialsCheck(_ context.Context, _ security.Candidate, _ security.Account, proposed security.Authentication) bool {
	return proposed != nil
}
//func isPreMFAVerify(_ context.Context, can security.Candidate, _ security.Account, auth security.Authentication) bool {
// if auth != nil {
// return false
// }
//
// if _, isMFAVerify := can.(*MFAOtpVerification); isMFAVerify {
// return true
// }
//
// _, isMFARefresh := can.(*MFAOtpRefresh)
// return isMFARefresh
//}
//func isPostMFAVerify(_ context.Context, can security.Candidate, _ security.Account, auth security.Authentication) bool {
// if auth == nil {
// return false
// }
//
// if _, isMFAVerify := can.(*MFAOtpVerification); isMFAVerify {
// return true
// }
//
// _, isMFARefresh := can.(*MFAOtpRefresh)
// return isMFARefresh
//}
// isFinalStage is a DecisionMakerConditionFunc that matches when the proposed authentication exists and its
// state is at least security.StateAuthenticated.
func isFinalStage(_ context.Context, _ security.Candidate, _ security.Account, proposed security.Authentication) bool {
	if proposed == nil {
		return false
	}
	return proposed.State() >= security.StateAuthenticated
}
/******************************
	Common Checks
 ******************************/

// AccountStatusChecker checks the account status. It rejects disabled and federated accounts, and rejects locked
// accounts unless they can be auto-unlocked according to the account's locking rules.
type AccountStatusChecker struct {
	store security.AccountStore
}

// NewAccountStatusChecker creates an AccountStatusChecker that loads locking rules from store.
func NewAccountStatusChecker(store security.AccountStore) *AccountStatusChecker {
	checker := AccountStatusChecker{store: store}
	return &checker
}

// Decide approves when acct is nil. Otherwise, checked in this order:
//   - a disabled account is rejected;
//   - a federated account is rejected because password login is not allowed;
//   - a locked account is rejected unless decideAutoUnlock unlocks it.
func (adm *AccountStatusChecker) Decide(ctx context.Context, _ security.Candidate, acct security.Account, _ security.Authentication) error {
	if acct == nil {
		return nil
	}
	if acct.Disabled() {
		return security.NewAccountStatusError(MessageAccountDisabled)
	}
	if acct.Type() == security.AccountTypeFederated {
		return security.NewAccountStatusError(MessagePasswordLoginNotAllowed)
	}
	if acct.Locked() {
		return adm.decideAutoUnlock(ctx, acct)
	}
	return nil
}
// decideAutoUnlock returns nil when acct is not locked, or when it gets unlocked here because its lockout
// duration has elapsed. Otherwise, it returns an "account locked" error.
//
// Auto-unlock requires all of the following:
//   - acct implements both security.AccountHistory and security.AccountUpdater;
//   - the lockout time is non-zero;
//   - locking rules load successfully, with lockout enabled and a positive lockout duration.
//
// Any failure (including an error loading the rules) keeps the account locked.
func (adm *AccountStatusChecker) decideAutoUnlock(ctx context.Context, acct security.Account) error {
	if !acct.Locked() {
		return nil
	}
	lockedErr := security.NewAccountStatusError(MessageAccountLocked)

	history, isHistory := acct.(security.AccountHistory)
	updater, isUpdater := acct.(security.AccountUpdater)
	if !isHistory || !isUpdater || history.LockoutTime().IsZero() {
		return lockedErr
	}

	rules, e := adm.store.LoadLockingRules(ctx, acct)
	if e != nil || rules == nil || !rules.LockoutEnabled() || rules.LockoutDuration() <= 0 {
		return lockedErr
	}

	unlockAt := history.LockoutTime().Add(rules.LockoutDuration())
	if time.Now().After(unlockAt) {
		updater.UnlockAccount()
		logger.WithContext(ctx).Infof("Account[%s] Unlocked", acct.Username())
	}
	if acct.Locked() {
		return lockedErr
	}
	return nil
}
// PasswordPolicyChecker takes the account's password aging policy into consideration.
// Expired passwords are allowed a limited number of graceful logins. Passwords that expire soon produce a warning
// in the authentication details.
type PasswordPolicyChecker struct {
	store security.AccountStore
}

// NewPasswordPolicyChecker creates a PasswordPolicyChecker that loads password aging rules from store.
func NewPasswordPolicyChecker(store security.AccountStore) *PasswordPolicyChecker {
	checker := PasswordPolicyChecker{store: store}
	return &checker
}

// Decide approves without further checks when acct does not implement both security.AccountHistory and
// security.AccountUpdater. It also approves when the aging rules can't be loaded, aren't enforced, or have a
// non-positive max age. Otherwise, it delegates to decideExpiredPassword or decideNonExpiredPassword.
func (c *PasswordPolicyChecker) Decide(ctx context.Context, _ security.Candidate, acct security.Account, auth security.Authentication) error {
	history, isHistory := acct.(security.AccountHistory)
	if _, isUpdater := acct.(security.AccountUpdater); !isHistory || !isUpdater {
		return nil
	}

	policy, e := c.store.LoadPwdAgingRules(ctx, acct)
	if e != nil || policy == nil || !policy.PwdAgingRuleEnforced() || policy.PwdMaxAge() <= 0 {
		return nil //nolint:nilerr // for now, we ignore this error and treat it as if policy is not enabled
	}

	expiresAt := history.PwdChangedTime().Add(policy.PwdMaxAge())
	if expiresAt.Before(time.Now()) {
		return c.decideExpiredPassword(ctx, acct, policy, auth)
	}
	return c.decideNonExpiredPassword(ctx, acct, policy, auth)
}
// decideNonExpiredPassword handles an account whose password has not expired. It resets the graceful auth count.
// When the password expires within the policy's warning period, it also adds an "expiring" warning to auth's
// details. It always returns nil.
// acct is expected to implement security.AccountUpdater and security.AccountHistory (checked by Decide).
func (c *PasswordPolicyChecker) decideNonExpiredPassword(
	_ context.Context, acct security.Account, policy security.AccountPwdAgingRule, auth security.Authentication) error {
	acct.(security.AccountUpdater).ResetGracefulAuthCount()

	changedAt := acct.(security.AccountHistory).PwdChangedTime()
	remaining := policy.PwdMaxAge() - time.Since(changedAt)
	if remaining < 0 || remaining >= policy.PwdExpiryWarningPeriod() {
		// outside the warning window: nothing to report
		return nil
	}
	c.addWarning(auth, fmt.Sprintf("Password is expring in %s", remaining.String()))
	return nil
}
func (c *PasswordPolicyChecker) decideExpiredPassword(
_ context.Context, acct security.Account, policy security.AccountPwdAgingRule, auth security.Authentication) error {
switch remaining := policy.GracefulAuthLimit() - acct.(security.AccountHistory).GracefulAuthCount(); {
case remaining <= 0:
// No more graceful auth
return security.NewCredentialsExpiredError(MessagePasswordExpired)
case remaining == 1:
// Last chance
c.addWarning(auth, "Last Graceful Login")
default:
// more chance available
c.addWarning(auth, fmt.Sprintf("%d Graceful Login Left", remaining))
}
acct.(security.AccountUpdater).IncrementGracefulAuthCount()
return nil
}
// addWarning appends warning to the warning list stored in auth's details map under
// security.DetailsKeyAuthWarning.
//   - If warning is a []interface{}, its elements are appended individually.
//   - An existing non-slice value under that key is kept as the first element of the new list.
//
// It is a no-op when auth is nil (e.g. a pre-credentials-check invocation, where the AuthenticationDecisionMaker
// contract passes a nil authentication), or when the details are not a non-nil map[string]interface{}.
func (c *PasswordPolicyChecker) addWarning(auth security.Authentication, warning interface{}) {
	if auth == nil {
		return
	}
	details, ok := auth.Details().(map[string]interface{})
	if !ok || details == nil {
		return
	}

	var existing []interface{}
	switch w := details[security.DetailsKeyAuthWarning].(type) {
	case nil:
		existing = []interface{}{}
	case []interface{}:
		existing = w
	default:
		existing = []interface{}{w}
	}

	if warnings, ok := warning.([]interface{}); ok {
		details[security.DetailsKeyAuthWarning] = append(existing, warnings...)
	} else {
		details[security.DetailsKeyAuthWarning] = append(existing, warning)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"time"
)
// PasswordAuthFeature configures password-based authentication, with optional OTP-based MFA support.
// Fields are set via the DSL-style setters on the feature.
type PasswordAuthFeature struct {
	accountStore    security.AccountStore
	passwordEncoder PasswordEncoder

	// MFA support
	mfaEnabled        bool
	mfaEventListeners []MFAEventListenerFunc
	otpTTL            time.Duration
	// OTP limits and sizes
	otpVerifyLimit, otpRefreshLimit, otpLength, otpSecretSize uint
}
// Configure is the standard security.Feature entrypoint.
// It enables a new PasswordAuthFeature on ws and returns the feature instance that FeatureModifier.Enable returns,
// for further configuration.
//
// It panics if ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *PasswordAuthFeature {
	feature := &PasswordAuthFeature{}
	if fm, ok := ws.(security.FeatureModifier); ok {
		return fm.Enable(feature).(*PasswordAuthFeature)
	}
	panic(fmt.Errorf("unable to configure password authenticator: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
}
// New is the standard security.Feature entrypoint, DSL style. Used with security.WebSecurity
func New() *PasswordAuthFeature {
	return new(PasswordAuthFeature)
}

// Identifier returns PasswordAuthenticatorFeatureId.
func (f *PasswordAuthFeature) Identifier() security.FeatureIdentifier {
	return PasswordAuthenticatorFeatureId
}

// AccountStore sets the security.AccountStore used for password authentication.
func (f *PasswordAuthFeature) AccountStore(store security.AccountStore) *PasswordAuthFeature {
	f.accountStore = store
	return f
}

// PasswordEncoder sets the PasswordEncoder used for password authentication.
func (f *PasswordAuthFeature) PasswordEncoder(encoder PasswordEncoder) *PasswordAuthFeature {
	f.passwordEncoder = encoder
	return f
}

// MFA enables or disables MFA support.
func (f *PasswordAuthFeature) MFA(enabled bool) *PasswordAuthFeature {
	f.mfaEnabled = enabled
	return f
}

// MFAEventListeners appends MFA event listeners to the ones already configured.
func (f *PasswordAuthFeature) MFAEventListeners(listeners ...MFAEventListenerFunc) *PasswordAuthFeature {
	f.mfaEventListeners = append(f.mfaEventListeners, listeners...)
	return f
}

// OtpTTL sets the OTP time-to-live.
func (f *PasswordAuthFeature) OtpTTL(ttl time.Duration) *PasswordAuthFeature {
	f.otpTTL = ttl
	return f
}

// OtpVerifyLimit sets the maximum number of OTP verification attempts.
func (f *PasswordAuthFeature) OtpVerifyLimit(limit uint) *PasswordAuthFeature {
	f.otpVerifyLimit = limit
	return f
}

// OtpRefreshLimit sets the maximum number of OTP refresh/resend attempts.
func (f *PasswordAuthFeature) OtpRefreshLimit(limit uint) *PasswordAuthFeature {
	f.otpRefreshLimit = limit
	return f
}

// OtpLength sets the OTP passcode length.
func (f *PasswordAuthFeature) OtpLength(length uint) *PasswordAuthFeature {
	f.otpLength = length
	return f
}

// OtpSecretSize sets the OTP secret size.
func (f *PasswordAuthFeature) OtpSecretSize(size uint) *PasswordAuthFeature {
	f.otpSecretSize = size
	return f
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import "github.com/cisco-open/go-lanai/pkg/security"
// MFAEvent identifies the kind of MFA (OTP) event delivered to MFAEventListenerFunc listeners.
type MFAEvent int

// Known MFAEvent values. The zero value is skipped (blank identifier), so real events start at 1.
// The first entry is explicitly typed so that every constant is an MFAEvent rather than an untyped int.
const (
	_ MFAEvent = iota
	MFAEventOtpCreate
	MFAEventOtpRefresh
	MFAEventVerificationSuccess
	MFAEventVerificationFailure
)
// MFAEventListenerFunc is notified of MFA events. broadcastMFAEvent passes the related security.Account as principal.
type MFAEventListenerFunc func(event MFAEvent, otp OTP, principal interface{})

/*****************************
	Common Implements
 *****************************/

// broadcastMFAEvent invokes every listener, in order, with the given event, OTP and account.
func broadcastMFAEvent(event MFAEvent, otp OTP, account security.Account, listeners ...MFAEventListenerFunc) {
	for i := range listeners {
		listeners[i](event, otp, account)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import (
"bytes"
"context"
"encoding/gob"
"fmt"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/google/uuid"
"github.com/pkg/errors"
"io"
"strings"
"time"
)
// redisKeyPrefixOtp prefixes OTP IDs to build Redis keys.
const redisKeyPrefixOtp = "OTP-"

// OTP is a one-time password together with its verification/refresh bookkeeping.
type OTP interface {
	// ID returns the identifier the OTP is stored under.
	ID() string
	// Passcode returns the current passcode.
	Passcode() string
	// TTL returns the lifetime the current passcode was issued with.
	TTL() time.Duration
	// Expire returns the time the current passcode expires.
	Expire() time.Time
	// Attempts returns the number of verification attempts made so far.
	Attempts() uint
	// Refreshes returns the number of refreshes made so far.
	Refreshes() uint
	// IncrementAttempts records one more verification attempt.
	IncrementAttempts()
	// IncrementRefreshes records one more refresh.
	IncrementRefreshes()
	// secret returns the shared secret the passcode is derived from.
	// Being unexported, it restricts OTP implementations to this package.
	secret() string
}

// OTPManager creates, loads, verifies, refreshes and deletes OTPs.
type OTPManager interface {
	// New creates a new OTP and saves it.
	New() (OTP, error)
	// Get loads an OTP by ID.
	Get(id string) (OTP, error)
	// Verify uses Get to load the OTP and checks the given passcode against it.
	// It returns the loaded OTP regardless of the verification result.
	// hasMoreChances is false once the maximum attempts limit is reached, otherwise true.
	// err indicates whether the given passcode is valid; it is nil if valid.
	Verify(id, passcode string) (loaded OTP, hasMoreChances bool, err error)
	// Refresh regenerates the OTP passcode without changing its secret and ID.
	// It returns the loaded or refreshed OTP regardless of the result.
	// hasMoreChances is false once the maximum attempts limit is reached, otherwise true.
	// err indicates whether the passcode was refreshed; it is nil if refreshed.
	Refresh(id string) (refreshed OTP, hasMoreChances bool, err error)
	// Delete deletes an OTP by ID.
	Delete(id string) error
}

// OTPStore persists OTPs by ID.
type OTPStore interface {
	Save(otp OTP) error
	Load(id string) (OTP, error)
	Delete(id string) error
}
/*****************************
	Common Implements
 *****************************/

// timeBasedOtp implements OTP on top of a TOTP value and tracks verification and refresh counts.
// NOTE(review): fields are exported, presumably so the value can be gob-encoded by serialize/deserialize.
type timeBasedOtp struct {
	Identifier   string
	Value        TOTP
	AttemptCount uint
	RefreshCount uint
}

func (o *timeBasedOtp) secret() string { return o.Value.Secret }

// ID returns the OTP identifier.
func (o *timeBasedOtp) ID() string { return o.Identifier }

// Passcode returns the current TOTP passcode.
func (o *timeBasedOtp) Passcode() string { return o.Value.Passcode }

// TTL returns the TTL of the current TOTP.
func (o *timeBasedOtp) TTL() time.Duration { return o.Value.TTL }

// Expire returns the expiration time of the current TOTP.
func (o *timeBasedOtp) Expire() time.Time { return o.Value.Expire }

// Attempts returns the number of verification attempts so far.
func (o *timeBasedOtp) Attempts() uint { return o.AttemptCount }

// Refreshes returns the number of refreshes so far.
func (o *timeBasedOtp) Refreshes() uint { return o.RefreshCount }

// IncrementAttempts increases the verification attempt count by one.
func (o *timeBasedOtp) IncrementAttempts() { o.AttemptCount++ }

// IncrementRefreshes increases the refresh count by one.
func (o *timeBasedOtp) IncrementRefreshes() { o.RefreshCount++ }
// totpManager implements OTPManager using a TOTPFactory for passcodes and an OTPStore for persistence.
type totpManager struct {
	factory         TOTPFactory
	store           OTPStore
	ttl             time.Duration
	maxVerifyLimit  uint
	maxRefreshLimit uint
}

// totpManagerOptionsFunc customizes a totpManager during construction.
type totpManagerOptionsFunc func(*totpManager)

// newTotpManager creates a totpManager and then applies options in order. Defaults:
//   - an in-memory store;
//   - a 10-minute TTL;
//   - 3 verification attempts and 3 refreshes.
//
// No default factory is set, so options are expected to provide one.
func newTotpManager(options ...totpManagerOptionsFunc) *totpManager {
	m := totpManager{
		store:           inmemOtpStore(map[string]OTP{}),
		ttl:             10 * time.Minute,
		maxVerifyLimit:  3,
		maxRefreshLimit: 3,
	}
	for i := range options {
		options[i](&m)
	}
	return &m
}
// New generates a TOTP with the manager's TTL and assigns it a random UUID as ID.
// It saves the OTP to the store and returns it. Store errors are returned as-is.
func (m *totpManager) New() (OTP, error) {
	uid, e := uuid.NewRandom()
	if e != nil {
		return nil, errors.Wrapf(e, "Unable to create TOTP")
	}
	value, e := m.factory.Generate(m.ttl)
	if e != nil {
		return nil, errors.Wrapf(e, "Unable to create TOTP")
	}

	created := &timeBasedOtp{Identifier: uid.String(), Value: value}
	if e = m.store.Save(created); e != nil {
		return nil, e
	}
	return created, nil
}
// Get loads the OTP with the given ID from the store. The returned OTP is nil whenever an error is returned.
func (m *totpManager) Get(id string) (OTP, error) {
	loaded, e := m.store.Load(id)
	if e != nil {
		loaded = nil
	}
	return loaded, e
}
// Verify implements OTPManager.Verify. It loads the OTP by id, counts this attempt, and validates passcode
// against the loaded OTP's secret and TTL.
//
// Results:
//   - loaded: the loaded OTP, also on failure; nil only when the OTP could not be loaded.
//   - hasMoreChances: whether the attempt count is still below maxVerifyLimit.
//   - err: nil when passcode is valid. Otherwise:
//     a CredentialsExpired error when the OTP can't be loaded;
//     a MaxAttemptsReached error when the attempt limit is exceeded, or the passcode is wrong on the last attempt;
//     a BadCredentials error for any other wrong passcode.
//
// When Verify returns, cleanup saves the updated OTP, or deletes it if expired.
func (m *totpManager) Verify(id, passcode string) (loaded OTP, hasMoreChances bool, err error) {
	// load OTP by ID
	otp, e := m.store.Load(id)
	if otp == nil || e != nil {
		return nil, false, security.NewCredentialsExpiredError("Passcode already expired", e)
	}

	// schedule for post verification: persists the incremented attempt count or deletes the expired OTP
	defer m.cleanup(otp)

	// check verification attempts.
	// The loaded OTP is returned on this path too, as documented by OTPManager.Verify and consistent with Refresh.
	if otp.IncrementAttempts(); otp.Attempts() > m.maxVerifyLimit {
		return otp, false, security.NewMaxAttemptsReachedError("Max verification attempts exceeded")
	}

	// totpFactory.Validate only reads Passcode, Secret and TTL; Expire is filled for completeness
	toValidate := TOTP{
		Passcode: passcode,
		Secret:   otp.secret(),
		TTL:      otp.TTL(),
		Expire:   time.Now().Add(otp.TTL()),
	}

	loaded = otp
	hasMoreChances = otp.Attempts() < m.maxVerifyLimit
	if valid, e := m.factory.Validate(toValidate); e != nil || !valid {
		if hasMoreChances {
			err = security.NewBadCredentialsError("Passcode doesn't match", e)
		} else {
			err = security.NewMaxAttemptsReachedError("Passcode doesn't match and max verification attempts exceeded")
		}
	}
	return loaded, hasMoreChances, err
}
// Refresh implements OTPManager.Refresh. It regenerates the passcode of the OTP with the given ID, keeping its
// secret and ID, with a TTL equal to the OTP's remaining lifetime.
//
// The loaded OTP is returned whenever it could be loaded as a *timeBasedOtp, including on failures.
// hasMoreChances reports whether the refresh count is still below maxRefreshLimit.
// When Refresh returns, cleanup saves the updated OTP, or deletes it if expired.
func (m *totpManager) Refresh(id string) (loaded OTP, hasMoreChances bool, err error) {
	stored, e := m.store.Load(id)
	if e != nil {
		return nil, false, security.NewCredentialsExpiredError("Passcode expired", e)
	}
	otp, ok := stored.(*timeBasedOtp)
	if !ok {
		return nil, false, security.NewCredentialsExpiredError("Passcode expired", e)
	}
	loaded = stored

	// schedule for post refresh
	defer m.cleanup(otp)

	// refresh limit
	if otp.IncrementRefreshes(); otp.Refreshes() > m.maxRefreshLimit {
		return loaded, false, security.NewMaxAttemptsReachedError("Max refresh/resend attempts exceeded")
	}

	// the refreshed passcode is only valid for the time remaining on the original OTP
	remaining := time.Until(otp.Expire())
	if remaining <= 0 {
		return loaded, false, security.NewCredentialsExpiredError("Passcode already expired")
	}

	hasMoreChances = otp.Refreshes() < m.maxRefreshLimit
	refreshed, e := m.factory.Refresh(otp.secret(), remaining)
	if e != nil {
		if !hasMoreChances {
			return loaded, false, security.NewMaxAttemptsReachedError("Unable to refresh/resend passcode and max refresh/resend attempts exceeded", e)
		}
		return loaded, true, security.NewAuthenticationError("Unable to refresh/resend passcode", e)
	}
	otp.Value = refreshed
	return loaded, hasMoreChances, nil
}
// Delete removes the OTP with the given ID from the store.
func (m *totpManager) Delete(id string) error {
	return m.store.Delete(id)
}

// cleanup persists otp after a verification or refresh. An unexpired OTP is saved with its updated state;
// an expired one is deleted. Store errors are ignored.
func (m *totpManager) cleanup(otp OTP) {
	if !time.Now().After(otp.Expire()) {
		_ = m.store.Save(otp)
		return
	}
	_ = m.store.Delete(otp.ID())
}
// inmemOtpStore implements OTPStore with a plain map keyed by OTP ID.
// NOTE(review): the map is not guarded by any lock, so this store is not safe for concurrent use.
// Entries are only removed via Delete.
type inmemOtpStore map[string]OTP

// Save stores otp under its ID, replacing any previous entry. It never fails.
func (s inmemOtpStore) Save(otp OTP) error {
	s[otp.ID()] = otp
	return nil
}

// Load returns the OTP stored under id, or an error if there is none.
func (s inmemOtpStore) Load(id string) (OTP, error) {
	otp, found := s[id]
	if !found {
		return nil, fmt.Errorf("not found with id %s", id)
	}
	return otp, nil
}

// Delete removes the OTP stored under id, or returns an error if there is none.
func (s inmemOtpStore) Delete(id string) error {
	if _, found := s[id]; !found {
		return fmt.Errorf("not found with id %s", id)
	}
	delete(s, id)
	return nil
}
// redisOtpStore implements OTPStore backed by Redis. Each OTP is stored gob-encoded (see serialize) under the key
// redisKeyPrefixOtp + ID, with a Redis expiration equal to the OTP's remaining lifetime.
type redisOtpStore struct {
	redisClient redis.Client
}

func newRedisOtpStore(redisClient redis.Client) *redisOtpStore {
	return &redisOtpStore{
		redisClient: redisClient,
	}
}

// Save stores otp with a Redis TTL equal to its remaining lifetime.
// If otp has already expired, its key is deleted instead. go-redis' SET treats a non-positive expiration as
// "no expiration" (or KEEPTTL), which would keep an expired OTP around indefinitely.
func (s *redisOtpStore) Save(otp OTP) error {
	key := s.key(otp.ID())
	ttl := time.Until(otp.Expire())
	if ttl <= 0 {
		return s.redisClient.Del(context.Background(), key).Err()
	}

	data, err := serialize(otp)
	if err != nil {
		return err
	}
	return s.redisClient.Set(context.Background(), key, data, ttl).Err()
}

// Load reads and deserializes the OTP stored under id.
// A missing key yields the client's error (presumably redis.Nil for go-redis; TODO confirm).
func (s *redisOtpStore) Load(id string) (OTP, error) {
	val, err := s.redisClient.Get(context.Background(), s.key(id)).Result()
	if err != nil {
		return nil, err
	}
	return deserialize(strings.NewReader(val))
}

// Delete removes the OTP stored under id.
func (s *redisOtpStore) Delete(id string) error {
	return s.redisClient.Del(context.Background(), s.key(id)).Err()
}

// key builds the Redis key for an OTP ID.
func (s *redisOtpStore) key(id string) string {
	return redisKeyPrefixOtp + id
}
// serialize gob-encodes otp through a pointer to the OTP interface, so the concrete type is transmitted with it.
// NOTE(review): gob requires the concrete type to be registered (gob.Register); presumably done elsewhere.
func serialize(otp OTP) ([]byte, error) {
	buf := new(bytes.Buffer)
	if err := gob.NewEncoder(buf).Encode(&otp); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// deserialize decodes an OTP previously encoded by serialize.
func deserialize(src io.Reader) (OTP, error) {
	var otp OTP
	if err := gob.NewDecoder(src).Decode(&otp); err != nil {
		return nil, err
	}
	return otp, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security"
"go.uber.org/fx"
)
var (
	logger = log.New("SEC.Passwd")

	// Module is the bootstrap module of the password authenticator; its fx options invoke register.
	Module = &bootstrap.Module{
		Name:       "passwd authenticator",
		Precedence: security.MinSecurityPrecedence + 30,
		Options: []fx.Option{
			fx.Invoke(register),
		},
	}
)

// init registers Module with bootstrap.
func init() {
	bootstrap.Register(Module)
}
// initDI holds the (all optional) dependencies used by register.
type initDI struct {
	fx.In
	SecRegistrar    security.Registrar    `optional:"true"`
	AccountStore    security.AccountStore `optional:"true"`
	PasswordEncoder PasswordEncoder       `optional:"true"`
	Redis           redis.Client          `optional:"true"`
}

// register registers the password authenticator feature configurer when a security.Registrar is available.
// The registrar must also implement security.FeatureRegistrar, otherwise the type assertion panics.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	configurer := newPasswordAuthConfigurer(di.AccountStore, di.PasswordEncoder, di.Redis)
	featureRegistrar := di.SecRegistrar.(security.FeatureRegistrar)
	featureRegistrar.RegisterFeature(PasswordAuthenticatorFeatureId, configurer)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"time"
)
// Orders of the built-in post-processors. Post-processors are executed in reversed order (highest priority last).
// Per the processors' Order docs, the execution order is:
// AccountStatus (first) -> AccountLocking -> AdditionalDetails -> PersistAccount (last).
const (
	postProcessorOrderAccountStatus     = order.Lowest
	postProcessorOrderAccountLocking    = 0
	postProcessorOrderAdditionalDetails = order.Highest + 1
	postProcessorOrderPersistAccount    = order.Highest
)
/******************************
	abstracts
 ******************************/

// AuthenticationResult carries the values passed through PostAuthenticationProcessor chains.
type AuthenticationResult struct {
	Candidate security.Candidate
	Auth      security.Authentication
	Error     error
}

// PostAuthenticationProcessor is invoked at the end of the authentication process, regardless of the
// authentication decision (granted or rejected).
// If a PostAuthenticationProcessor implements order.Ordered, its order is respected using
// order.OrderedFirstCompareReverse. This means the highest priority is executed last.
type PostAuthenticationProcessor interface {
	// Process is invoked by the Authenticator at the end of the authentication process to perform post-auth actions.
	// It is invoked whether or not the authentication is granted:
	//   - granted: AuthenticationResult.Auth is non-nil and AuthenticationResult.Error is nil;
	//   - rejected: AuthenticationResult.Error is non-nil and AuthenticationResult.Auth should be ignored.
	//
	// If the context.Context and security.Account arguments are mutable, Process is allowed to change them.
	// Note: Process typically shouldn't overwrite the authentication decision (rejected or approved).
	// However, it may modify the result by returning a different AuthenticationResult, e.g. to return a
	// different error or to add more details to the authentication.
	Process(ctx context.Context, account security.Account, result AuthenticationResult) AuthenticationResult
}
/******************************
	Helpers
 ******************************/

// applyPostAuthenticationProcessors runs processors in slice order. Each processor receives the result returned
// by the previous one. It returns the final authentication and error.
func applyPostAuthenticationProcessors(processors []PostAuthenticationProcessor,
	ctx context.Context, acct security.Account, can security.Candidate, auth security.Authentication, err error) (security.Authentication, error) {
	result := AuthenticationResult{Candidate: can, Auth: auth, Error: err}
	for i := range processors {
		result = processors[i].Process(ctx, acct, result)
	}
	return result.Auth, result.Error
}
/******************************
	Common Implementation
 ******************************/

// PersistAccountPostProcessor saves the Account. It implements order.Ordered with the highest priority.
// Note: post-processors are executed in reversed order
type PersistAccountPostProcessor struct {
	store security.AccountStore
}

// NewPersistAccountPostProcessor creates a PersistAccountPostProcessor that saves accounts to store.
func NewPersistAccountPostProcessor(store security.AccountStore) *PersistAccountPostProcessor {
	processor := PersistAccountPostProcessor{store: store}
	return &processor
}

// Order makes the processor run last
func (p *PersistAccountPostProcessor) Order() int {
	return postProcessorOrderPersistAccount
}

// Process saves acct (if non-nil) regardless of the decision, in case any status changed, and returns result unchanged.
// Save errors are not propagated because it's too late to act on them. Errors other than
// security.ErrorSubTypeInternalError are logged as warnings.
func (p *PersistAccountPostProcessor) Process(ctx context.Context, acct security.Account, result AuthenticationResult) AuthenticationResult {
	if acct == nil {
		return result
	}
	if e := p.store.Save(ctx, acct); e != nil && !errors.Is(e, security.ErrorSubTypeInternalError) {
		logger.WithContext(ctx).Warnf("account status was not persisted due to error: %v", e)
	}
	return result
}
// AccountStatusPostProcessor updates the account based on the authentication result,
// e.g. last login status and failed login status.
type AccountStatusPostProcessor struct {
	store security.AccountStore
}

// NewAccountStatusPostProcessor creates an AccountStatusPostProcessor that loads locking rules from store.
func NewAccountStatusPostProcessor(store security.AccountStore) *AccountStatusPostProcessor {
	processor := AccountStatusPostProcessor{store: store}
	return &processor
}

// Order makes the processor run first (reversed ordering)
func (p *AccountStatusPostProcessor) Order() int {
	return postProcessorOrderAccountStatus
}

// Process does nothing unless acct implements security.AccountUpdater. Otherwise:
//   - on full authentication (no error, state >= security.StateAuthenticated), it records a success and resets
//     failed attempts;
//   - on bad credentials for a password candidate, it records a failure. The limit comes from the locking rules
//     when they are loaded with lockout enabled, and defaults to 5.
//
// The result is always returned unchanged.
func (p *AccountStatusPostProcessor) Process(ctx context.Context, acct security.Account, result AuthenticationResult) AuthenticationResult {
	updater, ok := acct.(security.AccountUpdater)
	if !ok {
		return result
	}

	if result.Error == nil && result.Auth != nil && result.Auth.State() >= security.StateAuthenticated {
		// fully authenticated
		updater.RecordSuccess(time.Now())
		updater.ResetFailedAttempts()
		if history, isHistory := acct.(security.AccountHistory); isHistory && history.SerialFailedAttempts() != 0 {
			logger.WithContext(ctx).Warnf("Account [%s] failed to reset", acct.Username())
		}
		return result
	}

	if errors.Is(result.Error, errorBadCredentials) && isPasswordAuth(result) {
		// password auth failed with incorrect password
		limit := 5
		rules, e := p.store.LoadLockingRules(ctx, acct)
		if e == nil && rules != nil && rules.LockoutEnabled() {
			limit = rules.LockoutFailuresLimit()
		}
		updater.RecordFailure(time.Now(), limit)
	}
	return result
}
// AccountLockingPostProcessor reacts to failed authentication and locks the account if necessary.
type AccountLockingPostProcessor struct {
	store security.AccountStore
}

// NewAccountLockingPostProcessor creates an AccountLockingPostProcessor that loads locking rules from store.
func NewAccountLockingPostProcessor(store security.AccountStore) *AccountLockingPostProcessor {
	processor := AccountLockingPostProcessor{store: store}
	return &processor
}

// Order makes the processor run between AccountStatusPostProcessor and PersistAccountPostProcessor
func (p *AccountLockingPostProcessor) Order() int {
	return postProcessorOrderAccountLocking
}

// Process only acts on bad-credentials errors of password candidates for accounts implementing
// security.AccountUpdater and security.AccountHistory, with lockout enabled in the locking rules.
// It counts login failures that fall within LockoutFailuresInterval and happened after the last lockout time.
// If that count reaches LockoutFailuresLimit, it locks the account and wraps the error in an account status error.
func (p *AccountLockingPostProcessor) Process(ctx context.Context, acct security.Account, result AuthenticationResult) AuthenticationResult {
	// skip if the account is not updatable, the error is not bad credentials, or it's not password auth
	updater, updatable := acct.(security.AccountUpdater)
	if !updatable || !errors.Is(result.Error, errorBadCredentials) || !isPasswordAuth(result) {
		return result
	}

	history, hasHistory := acct.(security.AccountHistory)
	rules, e := p.store.LoadLockingRules(ctx, acct)
	if !hasHistory || e != nil || rules == nil || !rules.LockoutEnabled() {
		return result
	}

	// Note 1: we assume AccountStatusPostProcessor already updated login success/failure records
	// Note 2: we don't count login failures before the last lockout time. Whether this is necessary is TBD
	now := time.Now()
	recentFailures := 0
	for _, failedAt := range history.LoginFailures() {
		withinInterval := now.Sub(failedAt) <= rules.LockoutFailuresInterval()
		if withinInterval && failedAt.After(history.LockoutTime()) {
			recentFailures++
		}
	}

	if recentFailures < rules.LockoutFailuresLimit() {
		return result
	}
	// over the limit: lock the account
	updater.LockAccount()
	logger.WithContext(ctx).Infof("Account[%s] Locked", acct.Username())
	// Optional, change error message
	result.Error = security.NewAccountStatusError(MessageLockedDueToBadCredential, result.Error)
	return result
}
// AdditionalDetailsPostProcessor populates additional authentication details if the authentication is granted.
// It implements order.Ordered.
// Note: post-processors are executed in reversed order
type AdditionalDetailsPostProcessor struct{}

// NewAdditionalDetailsPostProcessor creates an AdditionalDetailsPostProcessor.
func NewAdditionalDetailsPostProcessor() *AdditionalDetailsPostProcessor {
	return &AdditionalDetailsPostProcessor{}
}

// Order returns postProcessorOrderAdditionalDetails (order.Highest + 1).
// With reversed ordering, the processor runs right before PersistAccountPostProcessor (order.Highest),
// i.e. second to last.
func (p *AdditionalDetailsPostProcessor) Order() int {
	return postProcessorOrderAdditionalDetails
}

// Process writes extra entries into the authentication's details map:
//   - the auth method (security.AuthMethodPassword);
//   - for MFA OTP verification candidates, the MFA-applied flag.
//
// It does nothing when authentication failed, when there is no authentication, or when the details are not a
// non-nil map[string]interface{}. The result is always returned unchanged.
func (p *AdditionalDetailsPostProcessor) Process(_ context.Context, _ security.Account, result AuthenticationResult) AuthenticationResult {
	if result.Error != nil || result.Auth == nil {
		return result
	}
	details, ok := result.Auth.Details().(map[string]interface{})
	// a typed nil map passes the type assertion but would panic on write
	if !ok || details == nil {
		return result
	}

	// auth method
	details[security.DetailsKeyAuthMethod] = security.AuthMethodPassword

	// MFA
	if isMfaVerify(result) {
		details[security.DetailsKeyMFAApplied] = true
	}
	return result
}
/******************************
	Helper
 ******************************/

// isPasswordAuth reports whether the result's candidate is a *UsernamePasswordPair.
func isPasswordAuth(result AuthenticationResult) bool {
	switch result.Candidate.(type) {
	case *UsernamePasswordPair:
		return true
	default:
		return false
	}
}

// isMfaVerify reports whether the result's candidate is a *MFAOtpVerification.
func isMfaVerify(result AuthenticationResult) bool {
	switch result.Candidate.(type) {
	case *MFAOtpVerification:
		return true
	default:
		return false
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package passwd
import (
"crypto/rand"
"encoding/base32"
"fmt"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"time"
)
// b32NoPadding is standard base32 without padding, used to encode generated secrets.
var b32NoPadding = base32.StdEncoding.WithPadding(base32.NoPadding)

// TOTP is a time-based one-time passcode together with its shared secret and validity window.
type TOTP struct {
	Passcode string
	Secret   string
	TTL      time.Duration
	Expire   time.Time
}

// TOTPFactory generates, refreshes and validates TOTPs.
type TOTPFactory interface {
	// Generate creates a TOTP with a new secret, valid for ttl.
	Generate(ttl time.Duration) (value TOTP, err error)
	// Refresh creates a new passcode from an existing secret, valid for ttl.
	Refresh(secret string, ttl time.Duration) (value TOTP, err error)
	// Validate checks the passcode of value against its secret and TTL.
	Validate(value TOTP) (valid bool, err error)
}
// totpFactory implements TOTPFactory using github.com/pquerna/otp/totp.
type totpFactory struct {
	skew       uint
	digits     otp.Digits
	alg        otp.Algorithm
	secretSize int
}

// totpFactoryOptions customizes a totpFactory during construction.
type totpFactoryOptions func(*totpFactory)

// newTotpFactory creates a totpFactory and then applies options in order.
// Defaults: skew 0, 6 digits, SHA1, and 20-byte secrets.
func newTotpFactory(options ...totpFactoryOptions) *totpFactory {
	f := totpFactory{
		skew:       0,
		digits:     6,
		alg:        otp.AlgorithmSHA1,
		secretSize: 20,
	}
	for i := range options {
		options[i](&f)
	}
	return &f
}
// Generate creates a TOTP with a freshly generated random secret, valid for ttl (see Refresh for ttl constraints).
func (f *totpFactory) Generate(ttl time.Duration) (TOTP, error) {
	secret, err := f.generateSecret()
	if err != nil {
		return TOTP{}, err
	}
	return f.Refresh(secret, ttl)
}
// Refresh generates a passcode for secret at the current time, using ttl rounded to whole seconds as the TOTP period.
// ttl must be at least one second. The returned TOTP expires one (rounded) ttl from now.
func (f *totpFactory) Refresh(secret string, ttl time.Duration) (TOTP, error) {
	if ttl < time.Second {
		return TOTP{}, fmt.Errorf("ttl should be greater or equals to 1 seconds")
	}

	issuedAt := time.Now()
	period := ttl.Round(time.Second)
	passcode, err := totp.GenerateCodeCustom(secret, issuedAt, totp.ValidateOpts{
		Period:    uint(period.Seconds()),
		Skew:      f.skew,
		Digits:    f.digits,
		Algorithm: f.alg,
	})
	if err != nil {
		return TOTP{}, err
	}

	return TOTP{
		Passcode: passcode,
		Secret:   secret,
		TTL:      period,
		Expire:   issuedAt.Add(period),
	}, nil
}
// Validate checks value.Passcode against value.Secret at the current time.
// It uses value.TTL rounded to whole seconds as the TOTP period, so TTL must be at least one second.
// value.Expire is not used.
func (f *totpFactory) Validate(value TOTP) (bool, error) {
	if value.TTL < time.Second {
		return false, fmt.Errorf("ttl should be greater or equals to 1 seconds")
	}
	opts := totp.ValidateOpts{
		Period:    uint(value.TTL.Round(time.Second).Seconds()),
		Skew:      f.skew,
		Digits:    f.digits,
		Algorithm: f.alg,
	}
	return totp.ValidateCustom(value.Passcode, value.Secret, time.Now(), opts)
}
// generateSecret returns f.secretSize cryptographically random bytes, encoded as unpadded standard base32.
func (f *totpFactory) generateSecret() (string, error) {
	secret := make([]byte, f.secretSize)
	// rand.Read fills the whole buffer (io.ReadFull on rand.Reader). A bare rand.Reader.Read call is only bound by
	// the io.Reader contract, which allows short reads and would leave part of the secret zeroed.
	if _, err := rand.Read(secret); err != nil {
		return "", err
	}
	return b32NoPadding.EncodeToString(secret), nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"net/http"
"strings"
"time"
)
/***********************
Session
************************/
// SessionPropertiesPrefix is the configuration prefix that BindSessionProperties binds SessionProperties under.
const SessionPropertiesPrefix = "security.session"

// SessionProperties holds session-related security configuration bound from SessionPropertiesPrefix.
// Default values are provided by NewSessionProperties.
type SessionProperties struct {
	Cookie          CookieProperties
	IdleTimeout     utils.Duration `json:"idle-timeout"`
	AbsoluteTimeout utils.Duration `json:"absolute-timeout"`
	// MaxConcurrentSession presumably caps concurrent sessions; NewSessionProperties defaults it to 0 (unlimited).
	MaxConcurrentSession int `json:"max-concurrent-sessions"`
	// DbIndex is bound from "db-index" (presumably a Redis DB index; TODO confirm).
	DbIndex int `json:"db-index"`
}
// CookieProperties configures the session cookie. The json tags name the configuration keys.
type CookieProperties struct {
	Domain         string `json:"domain"`
	MaxAge         int    `json:"max-age"`
	Secure         bool   `json:"secure"`
	HttpOnly       bool   `json:"http-only"`
	SameSiteString string `json:"same-site"`
	Path           string `json:"path"`
}

// SameSite converts SameSiteString into an http.SameSite mode.
// Matching is case-insensitive: "lax", "strict" and "none" map to the corresponding
// modes; any other value, including the empty string, yields http.SameSiteDefaultMode.
func (cp CookieProperties) SameSite() http.SameSite {
	mode := strings.ToLower(cp.SameSiteString)
	if mode == "lax" {
		return http.SameSiteLaxMode
	}
	if mode == "strict" {
		return http.SameSiteStrictMode
	}
	if mode == "none" {
		return http.SameSiteNoneMode
	}
	return http.SameSiteDefaultMode
}
// NewSessionProperties returns SessionProperties populated with defaults:
// an HttpOnly cookie on path "/", a 15 minute idle timeout, a 30 minute
// absolute timeout and no limit on concurrent sessions (0).
func NewSessionProperties() *SessionProperties {
	props := new(SessionProperties)
	props.Cookie.HttpOnly = true
	props.Cookie.Path = "/"
	props.IdleTimeout = utils.Duration(15 * time.Minute)
	props.AbsoluteTimeout = utils.Duration(30 * time.Minute)
	props.MaxConcurrentSession = 0 // unlimited
	return props
}
// BindSessionProperties starts from the defaults of NewSessionProperties and binds the
// application configuration found under SessionPropertiesPrefix on top of them.
// It panics when binding fails.
func BindSessionProperties(ctx *bootstrap.ApplicationContext) SessionProperties {
	props := NewSessionProperties()
	bindErr := ctx.Config().Bind(props, SessionPropertiesPrefix)
	if bindErr != nil {
		panic(errors.Wrap(bindErr, "failed to bind SessionProperties"))
	}
	return *props
}
// TimeoutPropertiesPrefix is the configuration prefix that BindTimeoutSupportProperties binds from.
const TimeoutPropertiesPrefix = "security.timeout-support"

// TimeoutSupportProperties holds configuration for the security timeout support.
type TimeoutSupportProperties struct {
	DbIndex int `json:"db-index"`
}

// NewTimeoutSupportProperties returns TimeoutSupportProperties with zero-value defaults.
func NewTimeoutSupportProperties() *TimeoutSupportProperties {
	return new(TimeoutSupportProperties)
}

// BindTimeoutSupportProperties binds the configuration found under TimeoutPropertiesPrefix
// onto fresh TimeoutSupportProperties. It panics when binding fails.
func BindTimeoutSupportProperties(ctx *bootstrap.ApplicationContext) TimeoutSupportProperties {
	props := NewTimeoutSupportProperties()
	if bindErr := ctx.Config().Bind(props, TimeoutPropertiesPrefix); bindErr != nil {
		panic(errors.Wrap(bindErr, "failed to bind TimeoutSupportProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package redirect
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/gin-gonic/gin"
"net/http"
urlutils "net/url"
"path"
)
// Session flash keys used by RedirectHandler.doRedirect to store the error that caused
// the redirect and the HTTP status code associated with it.
const (
	FlashKeyPreviousError      = "error"
	FlashKeyPreviousStatusCode = "status"
)
// RedirectHandler implements multiple interfaces for authentication and error handling strategies
// (entry point, access denied, authentication success and authentication error handling)
// by redirecting to a fixed location.
//
//goland:noinspection GoNameStartsWithPackageName
type RedirectHandler struct {
	sc            int
	location      string
	ignoreCtxPath bool
}

// NewRedirectWithRelativePath creates a 302 RedirectHandler for the given path.
// When ignoreCtxPath is false, relative locations are prefixed with the web context path at redirect time.
// It panics when path cannot be parsed as a URL.
func NewRedirectWithRelativePath(path string, ignoreCtxPath bool) *RedirectHandler {
	handler := NewRedirectWithURL(path)
	handler.ignoreCtxPath = ignoreCtxPath
	return handler
}

// NewRedirectWithURL creates a 302 RedirectHandler for urlStr.
// It panics when urlStr cannot be parsed as a URL.
func NewRedirectWithURL(urlStr string) *RedirectHandler {
	parsed, err := urlutils.Parse(urlStr)
	if err != nil {
		panic(err)
	}
	return &RedirectHandler{sc: 302, location: parsed.String()}
}
// Commence implements security.AuthenticationEntryPoint.
// It redirects, flashing err together with http.StatusUnauthorized.
func (ep *RedirectHandler) Commence(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	flashes := map[string]interface{}{
		FlashKeyPreviousStatusCode: http.StatusUnauthorized,
		FlashKeyPreviousError:      err,
	}
	ep.doRedirect(c, r, rw, flashes)
}

// HandleAccessDenied implements security.AccessDeniedHandler.
// It redirects, flashing err together with http.StatusForbidden.
func (ep *RedirectHandler) HandleAccessDenied(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	flashes := map[string]interface{}{
		FlashKeyPreviousStatusCode: http.StatusForbidden,
		FlashKeyPreviousError:      err,
	}
	ep.doRedirect(c, r, rw, flashes)
}

// HandleAuthenticationSuccess implements security.AuthenticationSuccessHandler.
// It redirects without any flashes.
func (ep *RedirectHandler) HandleAuthenticationSuccess(c context.Context, r *http.Request, rw http.ResponseWriter, from, to security.Authentication) {
	ep.doRedirect(c, r, rw, nil)
}

// HandleAuthenticationError implements security.AuthenticationErrorHandler.
// It redirects, flashing err together with http.StatusUnauthorized.
func (ep *RedirectHandler) HandleAuthenticationError(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	flashes := map[string]interface{}{
		FlashKeyPreviousStatusCode: http.StatusUnauthorized,
		FlashKeyPreviousError:      err,
	}
	ep.doRedirect(c, r, rw, flashes)
}
// doRedirect sends an ep.sc redirect to ep.location unless the gin response was already written.
// Non-empty flashes are added to the current session, if there is one.
// A relative location is prefixed with the web context path unless ep.ignoreCtxPath is set.
func (ep *RedirectHandler) doRedirect(c context.Context, r *http.Request, rw http.ResponseWriter, flashes map[string]interface{}) {
	// once the response is written, redirect headers can no longer be sent
	if grw, ok := rw.(gin.ResponseWriter); ok && grw.Written() {
		return
	}

	// len of a nil map is 0, so no separate nil check is needed
	if len(flashes) != 0 {
		if s := session.Get(c); s != nil {
			for k, v := range flashes {
				s.AddFlash(v, k)
			}
		}
	}

	// ep.location is stored as url.URL.String() output by the constructors and is expected to re-parse cleanly
	location, _ := urlutils.Parse(ep.location)
	if !location.IsAbs() && !ep.ignoreCtxPath {
		// relative path was used: prefix it with the context path
		location.Path = path.Join(web.ContextPath(c), location.Path)
	}

	http.Redirect(rw, r, location.String(), ep.sc)
	// NOTE(review): the empty write presumably makes the writer record the response as written — confirm
	_, _ = rw.Write([]byte{})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package request_cache
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/security/session/common"
"github.com/cisco-open/go-lanai/pkg/web"
)
var (
	// FeatureId identifies the request cache Feature, ordered by security.FeatureOrderRequestCache.
	FeatureId = security.FeatureId("request_cache", security.FeatureOrderRequestCache)
)
// Feature is the request cache security.Feature.
// sessionName is the session cookie name used to look up cached requests.
type Feature struct {
	sessionName string
}

// Identifier implements security.Feature; it always returns FeatureId.
func (f *Feature) Identifier() security.FeatureIdentifier {
	return FeatureId
}

// SessionName sets the session name and returns f for chaining.
func (f *Feature) SessionName(sessionName string) *Feature {
	f.sessionName = sessionName
	return f
}

// Configure Standard security.Feature entrypoint.
// It panics when ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *Feature {
	feature := New()
	modifier, ok := ws.(security.FeatureModifier)
	if !ok {
		panic(fmt.Errorf("unable to configure session: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
	}
	return modifier.Enable(feature).(*Feature)
}

// New Standard security.Feature entrypoint, DSL style. Used with security.WebSecurity.
// The session name defaults to common.DefaultName.
func New() *Feature {
	feature := &Feature{}
	feature.sessionName = common.DefaultName
	return feature
}
// Configurer applies the request cache Feature to a security.WebSecurity.
type Configurer struct {
	// cachedRequestPreProcessor is created by Apply the first time a session store is
	// available and reused for subsequent Apply calls.
	cachedRequestPreProcessor *CachedRequestPreProcessor
}

// newConfigurer returns a Configurer with no pre-processor created yet.
func newConfigurer() *Configurer {
	return new(Configurer)
}
func (sc *Configurer) Apply(feature security.Feature, ws security.WebSecurity) error {
f := feature.(*Feature)
if len(f.sessionName) == 0 {
f.sessionName = common.DefaultName
}
if sc.cachedRequestPreProcessor == nil {
if store, ok := ws.Shared(security.WSSharedKeySessionStore).(session.Store); ok {
p := newCachedRequestPreProcessor(f.sessionName, store)
sc.cachedRequestPreProcessor = p
if ws.Shared(security.WSSharedKeyRequestPreProcessors) == nil {
ps := map[web.RequestPreProcessorName]web.RequestPreProcessor{p.Name():p}
_ = ws.AddShared(security.WSSharedKeyRequestPreProcessors, ps)
} else if ps, ok := ws.Shared(security.WSSharedKeyRequestPreProcessors).(map[web.RequestPreProcessorName]web.RequestPreProcessor); ok {
if _, exists := ps[p.name]; !exists {
ps[p.Name()] = p
}
}
}
}
return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package request_cache
import (
"encoding/gob"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"go.uber.org/fx"
)
// Module is the bootstrap module for request_cache. Its precedence
// (security.MinSecurityPrecedence + 20) places it after the session module.
// On start it invokes register.
var Module = &bootstrap.Module{
	Name:       "request_cache",
	Precedence: security.MinSecurityPrecedence + 20, //after session
	Options: []fx.Option{
		fx.Invoke(register),
	},
}

// init registers Module with bootstrap and registers *CachedRequest with encoding/gob.
// NOTE(review): the gob registration is presumably needed so sessions holding a CachedRequest
// can be serialized by the session store — confirm.
func init() {
	bootstrap.Register(Module)
	GobRegister()
}
// GobRegister registers *CachedRequest with encoding/gob.
func GobRegister() {
	gob.Register((*CachedRequest)(nil))
}

// initDI declares the dependencies of register; the security registrar is optional.
type initDI struct {
	fx.In
	SecRegistrar security.Registrar `optional:"true"`
}

// register registers FeatureId with a new Configurer when a security registrar is present.
// It panics if the registrar does not implement security.FeatureRegistrar.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	registrar := di.SecRegistrar.(security.FeatureRegistrar)
	registrar.RegisterFeature(FeatureId, newConfigurer())
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package request_cache
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/gin-gonic/gin"
"net/http"
"net/url"
"reflect"
)
// SessionKeyCachedRequest is the session key under which SaveRequest stores the CachedRequest.
const SessionKeyCachedRequest = "CachedRequest"

// CachedRequest is a snapshot of an HTTP request, saved in the session by SaveRequest and restored by
// CachedRequestPreProcessor. The body is not kept; form values are captured via ParseForm instead.
type CachedRequest struct {
	Method   string
	URL      *url.URL
	Header   http.Header
	Form     url.Values
	PostForm url.Values
	Host     string
}
// SaveRequest stores a snapshot of the current request in the session under SessionKeyCachedRequest.
// It is a no-op when ctx carries no gin context or no session.
func SaveRequest(ctx context.Context) {
	gc := web.GinContext(ctx)
	if gc == nil {
		return
	}
	s := session.Get(ctx)
	if s == nil {
		return
	}
	// We don't know if other components have already parsed the form. If they have, the body is already
	// consumed and reading it again would just yield "". Calling ParseForm makes sure the values end up in the
	// Form/PostForm fields, which are what we serialize. The error is ignored: caching is best-effort.
	_ = gc.Request.ParseForm()
	s.Set(SessionKeyCachedRequest, &CachedRequest{
		Method:   gc.Request.Method,
		URL:      gc.Request.URL,
		Host:     gc.Request.Host,
		PostForm: gc.Request.PostForm,
		Form:     gc.Request.Form,
		Header:   gc.Request.Header,
	})
}

// GetCachedRequest returns the CachedRequest stored in the current session,
// or nil when there is no session or nothing cached.
func GetCachedRequest(ctx context.Context) *CachedRequest {
	s := session.Get(ctx)
	if s == nil {
		return nil
	}
	cached, _ := s.Get(SessionKeyCachedRequest).(*CachedRequest)
	return cached
}

// RemoveCachedRequest deletes the cached request from the current session, if any.
func RemoveCachedRequest(ctx *gin.Context) {
	if s := session.Get(ctx); s != nil {
		s.Delete(SessionKeyCachedRequest)
	}
}
// CachedRequestPreProcessor is designed to be used by code outside of the security package.
// It is registered as a web.RequestPreProcessor in the shared request pre-processor map.
// It restores a request previously cached in the session onto a matching incoming request.
type CachedRequestPreProcessor struct {
	sessionName string
	store       session.Store
	name        web.RequestPreProcessorName
}

// newCachedRequestPreProcessor creates a pre-processor that reads sessions named sessionName from store.
func newCachedRequestPreProcessor(sessionName string, store session.Store) *CachedRequestPreProcessor {
	p := &CachedRequestPreProcessor{name: "CachedRequestPreProcessor"}
	p.sessionName = sessionName
	p.store = store
	return p
}

// Name returns the pre-processor's registration name.
func (p *CachedRequestPreProcessor) Name() web.RequestPreProcessorName {
	return p.name
}
// Process restores a cached request onto r. Restoration happens only when all of these hold:
//   - r carries the session cookie;
//   - the session holds a CachedRequest;
//   - r matches that CachedRequest (see requestMatches).
//
// When restoring, the cached entry is first removed from the session and the session saved. Then r's
// method, form values and headers are replaced with the cached ones, keeping r's own Cookie header.
// Session lookup failures are treated as "nothing cached". A failure to save the session is returned.
func (p *CachedRequestPreProcessor) Process(r *http.Request) error {
	sessionCookie, e := r.Cookie(p.sessionName)
	if e != nil {
		return nil
	}
	store := p.store.WithContext(r.Context())
	s, e := store.Get(sessionCookie.Value, p.sessionName)
	if e != nil {
		return nil
	}
	cached, ok := s.Get(SessionKeyCachedRequest).(*CachedRequest)
	if !ok || cached == nil || !requestMatches(r, cached) {
		return nil
	}
	// the cached request is restored at most once
	s.Delete(SessionKeyCachedRequest)
	if e = store.Save(s); e != nil {
		return e
	}
	r.Method = cached.Method
	// Only GET requests match (see requestMatches), so the incoming body is not used; form and
	// post form are set directly. Multipart forms and requests with trailers are never cached.
	r.Form = cached.Form
	r.PostForm = cached.PostForm
	// restore all cached headers except Cookie, which keeps the incoming value
	if cached.Header != nil {
		incomingCookies := r.Header["Cookie"]
		r.Header = cached.Header
		r.Header["Cookie"] = incomingCookies
	}
	return nil
}
// requestMatches reports whether r is a GET request whose URL and host equal the cached ones.
// Only GET is supported because a redirect is the only thing issued after authentication success.
func requestMatches(r *http.Request, cached *CachedRequest) bool {
	if r.Method == http.MethodGet {
		return reflect.DeepEqual(r.URL, cached.URL) && r.Host == cached.Host
	}
	return false
}
// NewSavedRequestAuthenticationSuccessHandler creates a handler that redirects to the cached request, if any,
// and otherwise delegates to fallback. A nil condition defaults to security.IsBeingAuthenticated.
// A nil fallback is allowed and means "do nothing" when no request is cached.
func NewSavedRequestAuthenticationSuccessHandler(fallback security.AuthenticationSuccessHandler, condition func(from, to security.Authentication) bool) security.AuthenticationSuccessHandler {
	if condition == nil {
		condition = security.IsBeingAuthenticated
	}
	return &SavedRequestAuthenticationSuccessHandler{
		condition: condition,
		fallback:  fallback,
	}
}

// SavedRequestAuthenticationSuccessHandler redirects to a previously cached request after authentication.
type SavedRequestAuthenticationSuccessHandler struct {
	condition func(from, to security.Authentication) bool
	fallback  security.AuthenticationSuccessHandler
}

// HandleAuthenticationSuccess does nothing unless condition(from, to) holds. It then sends a 302 redirect
// to the cached request's URI when one is cached, and otherwise delegates to the fallback handler.
func (h *SavedRequestAuthenticationSuccessHandler) HandleAuthenticationSuccess(c context.Context, r *http.Request, rw http.ResponseWriter, from, to security.Authentication) {
	if !h.condition(from, to) {
		return
	}
	if cached := GetCachedRequest(c); cached != nil && cached.URL != nil {
		http.Redirect(rw, r, cached.URL.RequestURI(), http.StatusFound)
		_, _ = rw.Write([]byte{})
		return
	}
	if h.fallback != nil {
		h.fallback.HandleAuthenticationSuccess(c, r, rw, from, to)
	}
}
// SaveRequestEntryPoint wraps another security.AuthenticationEntryPoint and, before delegating,
// caches the current request in the session (SaveRequest) when it is eligible.
type SaveRequestEntryPoint struct {
	delegate           security.AuthenticationEntryPoint
	saveRequestMatcher web.RequestMatcher
}

// NewSaveRequestEntryPoint creates a SaveRequestEntryPoint around delegate. A request is not cached when:
//   - it targets a favicon;
//   - it is an XMLHttpRequest;
//   - it has a Trailer header;
//   - its Content-Type matches multipart/form-data;
//   - it carries a CSRF token, either as a header or as a post-form parameter.
func NewSaveRequestEntryPoint(delegate security.AuthenticationEntryPoint) *SaveRequestEntryPoint {
	excludeFavicon := matcher.NotRequest(matcher.RequestWithPattern("/**/favicon.*"))
	excludeAjax := matcher.NotRequest(matcher.RequestWithHeader("X-Requested-With", "XMLHttpRequest", false))
	excludeTrailer := matcher.NotRequest(matcher.RequestHasHeader("Trailer"))
	excludeMultipart := matcher.NotRequest(matcher.RequestWithHeader("Content-Type", "multipart/form-data", true))
	csrfPresent := matcher.RequestHasHeader(security.CsrfHeaderName).Or(matcher.RequestHasPostForm(security.CsrfParamName))
	excludeCsrf := matcher.NotRequest(csrfPresent)
	return &SaveRequestEntryPoint{
		delegate:           delegate,
		saveRequestMatcher: excludeFavicon.And(excludeAjax).And(excludeTrailer).And(excludeMultipart).And(excludeCsrf),
	}
}

// Commence implements security.AuthenticationEntryPoint.
func (s *SaveRequestEntryPoint) Commence(c context.Context, r *http.Request, w http.ResponseWriter, e error) {
	if matched, err := s.saveRequestMatcher.MatchesWithContext(c, r); err == nil && matched {
		SaveRequest(c)
	}
	s.delegate.Commence(c, r, w, e)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
dsig "github.com/russellhaering/goxmldsig"
"net/url"
)
// samlConfigurer holds what is needed to build the SAML IdP Options and metadata middleware.
type samlConfigurer struct {
	// properties supply the certificate file, key file and key password.
	properties samlctx.SamlProperties
	// samlClientStore is used as the service provider manager and by the metadata middleware.
	samlClientStore samlctx.SamlClientStore
}
func (c *samlConfigurer) getIdentityProviderConfiguration(f *Feature) *Options {
cert, err := cryptoutils.LoadCert(c.properties.CertificateFile)
if err != nil {
panic(security.NewInternalError("cannot load certificate from file", err))
}
if len(cert) > 1 {
logger.Warnf("multiple certificate found, using first one")
}
key, err := cryptoutils.LoadPrivateKey(c.properties.KeyFile, c.properties.KeyPassword)
if err != nil {
panic(security.NewInternalError("cannot load private key from file", err))
}
rootURL, err := f.issuer.BuildUrl()
if err != nil {
panic(security.NewInternalError("cannot get issuer's base URL", err))
}
var signingMethod string
switch f.signingMethod {
case dsig.RSASHA1SignatureMethod:
fallthrough
case dsig.RSASHA256SignatureMethod:
fallthrough
case dsig.RSASHA512SignatureMethod:
signingMethod = f.signingMethod
default:
signingMethod = dsig.RSASHA1SignatureMethod
}
return &Options{
Key: key,
Cert: cert[0],
//usually this is the metadata url, but to keep consistent with existing implementation, we just use the context path
EntityIdUrl: *rootURL,
SsoUrl: *rootURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("%s%s", rootURL.Path, f.ssoLocation.Path),
RawQuery: f.ssoLocation.RawQuery,
}),
SloUrl: *rootURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("%s%s", rootURL.Path, f.logoutUrl),
}),
SigningMethod: signingMethod,
serviceProviderManager: c.samlClientStore,
}
}
// metadataMiddleware creates the IdP MetadataMiddleware from the options built for f.
func (c *samlConfigurer) metadataMiddleware(f *Feature) *MetadataMiddleware {
	return NewMetadataMiddleware(c.getIdentityProviderConfiguration(f), c.samlClientStore)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"github.com/cisco-open/go-lanai/pkg/utils"
)
// DefaultSamlClient is a SAML service provider definition made of SamlSpDetails plus tenant restriction
// settings. Its methods expose those fields.
type DefaultSamlClient struct {
	SamlSpDetails
	TenantRestrictions    utils.StringSet
	TenantRestrictionType string
}

// GetEntityId returns the SP entity ID.
func (dc DefaultSamlClient) GetEntityId() string {
	return dc.EntityId
}

// GetMetadataSource returns the SP metadata source.
func (dc DefaultSamlClient) GetMetadataSource() string {
	return dc.MetadataSource
}

// ShouldMetadataRequireSignature returns MetadataRequireSignature.
func (dc DefaultSamlClient) ShouldMetadataRequireSignature() bool {
	return dc.MetadataRequireSignature
}

// ShouldMetadataTrustCheck returns MetadataTrustCheck.
func (dc DefaultSamlClient) ShouldMetadataTrustCheck() bool {
	return dc.MetadataTrustCheck
}

// GetMetadataTrustedKeys returns MetadataTrustedKeys.
func (dc DefaultSamlClient) GetMetadataTrustedKeys() []string {
	return dc.MetadataTrustedKeys
}

// ShouldSkipAssertionEncryption returns SkipAssertionEncryption.
func (dc DefaultSamlClient) ShouldSkipAssertionEncryption() bool {
	return dc.SkipAssertionEncryption
}

// ShouldSkipAuthRequestSignatureVerification returns SkipAuthRequestSignatureVerification.
func (dc DefaultSamlClient) ShouldSkipAuthRequestSignatureVerification() bool {
	return dc.SkipAuthRequestSignatureVerification
}

// GetTenantRestrictions returns TenantRestrictions.
func (dc DefaultSamlClient) GetTenantRestrictions() utils.StringSet {
	return dc.TenantRestrictions
}

// GetTenantRestrictionType returns TenantRestrictionType.
func (dc DefaultSamlClient) GetTenantRestrictionType() string {
	return dc.TenantRestrictionType
}
// SamlSpDetails describes a SAML service provider registration (embedded in DefaultSamlClient).
type SamlSpDetails struct {
	// EntityId is the SP entity ID.
	EntityId string
	// MetadataSource NOTE(review): presumably a URL or file location of the SP metadata — confirm.
	MetadataSource                       string
	SkipAssertionEncryption              bool
	SkipAuthRequestSignatureVerification bool
	MetadataRequireSignature             bool
	MetadataTrustCheck                   bool
	MetadataTrustedKeys                  []string
	//currently the implementation is metaiop profile. this field is reserved for future use
	// https://docs.spring.io/autorepo/docs/spring-security-saml/1.0.x-SNAPSHOT/reference/htmlsingle/#configuration-security-profiles-pkix
	SecurityProfile string
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
errorutils "github.com/cisco-open/go-lanai/pkg/utils/error"
"github.com/crewjam/saml"
"net/http"
)
// CtxKeySamlAuthnRequest is the context key under which the *saml.IdpAuthnRequest is looked up
// when answering SSO errors.
const CtxKeySamlAuthnRequest = "kSamlAuthnRequest"

// SamlErrorHandler turns SAML errors into SAML responses or HTML error pages.
type SamlErrorHandler struct{}

// NewSamlErrorHandler returns a new SamlErrorHandler.
func NewSamlErrorHandler() *SamlErrorHandler {
	return new(SamlErrorHandler)
}
// HandleError handles err as a SAML response when possible; errors that are not caused by a SAML
// error are left to the regular error handling.
// Internal errors are rendered as HTML, while SSO and SLO errors are delegated to their dedicated handlers.
// See http://docs.oasis-open.org/security/saml/v2.0/saml-profiles-2.0-os.pdf 4.1.3.5
//
//nolint:errorlint
func (h *SamlErrorHandler) HandleError(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	cause := h.findCause(err)
	if !errors.Is(cause, security.ErrorTypeSaml) {
		return
	}
	switch {
	case errors.Is(cause, ErrorSubTypeSamlInternal):
		writeErrorAsHtml(c, r, rw, cause)
	case errors.Is(cause, ErrorSubTypeSamlSso):
		h.handleSsoError(c, r, rw, cause)
	case errors.Is(cause, ErrorSubTypeSamlSlo):
		h.handleSloError(c, r, rw, cause)
	}
}
// findCause walks the chain of errorutils.NestedError causes starting at err and returns the first
// error that is a SAML error. If the chain ends without one, err itself is returned.
//
//nolint:errorlint
func (h *SamlErrorHandler) findCause(err error) error {
	for current := err; ; {
		if errors.Is(current, security.ErrorTypeSaml) {
			return current
		}
		nested, ok := current.(errorutils.NestedError)
		if !ok {
			return err
		}
		current = nested.Cause()
	}
}
// handleSsoError writes err back to the service provider as a SAML error Response via the original
// AuthnRequest, found in c under CtxKeySamlAuthnRequest.
// When no AuthnRequest is available, or the response cannot be created or written, the error is rendered
// as an HTML page instead. Each failure path returns immediately so that at most one response is written.
//
//nolint:errorlint
func (h *SamlErrorHandler) handleSsoError(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	authRequest, ok := c.Value(CtxKeySamlAuthnRequest).(*saml.IdpAuthnRequest)
	if !ok || authRequest == nil {
		// without the original AuthnRequest there is no SP to respond to
		writeErrorAsHtml(c, r, rw, err)
		return
	}
	code := saml.StatusResponder
	message := ""
	if translator, ok := err.(SamlSsoErrorTranslator); ok { //all the saml sub types should implement the translator API
		code = translator.TranslateErrorCode()
		message = translator.TranslateErrorMessage()
	}
	if respErr := MakeErrorResponse(authRequest, code, message); respErr != nil {
		writeErrorAsHtml(c, r, rw, NewSamlInternalError("cannot create response", respErr))
		return
	}
	if writeErr := authRequest.WriteResponse(rw); writeErr != nil {
		writeErrorAsHtml(c, r, rw, NewSamlInternalError("cannot write response", writeErr))
	}
}
// handleSloError answers an SLO error with a SAML LogoutResponse via the SamlLogoutRequest found in c.
// It renders the error as HTML instead when:
//   - no logout request is available;
//   - the error is a requester error (ErrorSamlSloRequester);
//   - the logout response cannot be created or sent.
//
//nolint:errorlint
func (h *SamlErrorHandler) handleSloError(c context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	sloRequest, found := c.Value(ctxKeySloRequest).(*SamlLogoutRequest)
	if !found {
		writeErrorAsHtml(c, r, rw, err)
		return
	}

	code := saml.StatusAuthnFailed
	message := err.Error()
	if translator, isTranslator := err.(SamlSsoErrorTranslator); isTranslator { //all the saml sub types should implement the translator API
		code = translator.TranslateErrorCode()
		message = translator.TranslateErrorMessage()
	}

	// requester error means the requester is not validated, so errors are displayed as HTML
	if errors.Is(err, ErrorSamlSloRequester) {
		writeErrorAsHtml(c, r, rw, err)
		return
	}

	resp, createErr := MakeLogoutResponse(sloRequest, code, message)
	if createErr != nil {
		msg := fmt.Sprintf("unable to create logout error response with code [%s]: %s. Reason: %v", code, message, createErr)
		writeErrorAsHtml(c, r, rw, NewSamlInternalError(msg, createErr))
		return
	}

	sloRequest.Response = resp
	if sendErr := sloRequest.WriteResponse(rw); sendErr != nil {
		msg := fmt.Sprintf("unable to send logout error response with code [%s]: %s. Reason: %v", code, message, sendErr)
		writeErrorAsHtml(c, r, rw, NewSamlInternalError(msg, sendErr))
	}
}
// writeErrorAsHtml renders err as an HTML error page. The HTTP status comes from the error's
// translator when it implements SamlSsoErrorTranslator, and is 500 otherwise.
func writeErrorAsHtml(c context.Context, _ *http.Request, rw http.ResponseWriter, err error) {
	status := http.StatusInternalServerError
	//nolint:errorlint
	translator, ok := err.(SamlSsoErrorTranslator) //all the saml errors should implement this interface
	if ok {
		status = translator.TranslateHttpStatusCode()
	}
	security.WriteErrorAsHtml(c, rw, status, err)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
errorutils "github.com/cisco-open/go-lanai/pkg/utils/error"
"github.com/crewjam/saml"
"net/http"
)
// Errors in this file map to the status codes described in section 3.2.2 of
// http://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf
//
// Sub-type codes: the blank identifier consumes iota 0, so the sub-types take iota 1, 2 and 3.
// Each value is security.ErrorTypeCodeSaml + (iota << errorutils.ErrorSubTypeOffset), because << binds
// tighter than +. The expression is implicitly repeated for the following constants.
const (
	_ = iota
	// ErrorSubTypeCodeSamlSso non-programming error that can occur during SAML web sso flow. These errors will be returned to the requester
	// as a status code when possible
	ErrorSubTypeCodeSamlSso = security.ErrorTypeCodeSaml + iota<<errorutils.ErrorSubTypeOffset
	// ErrorSubTypeCodeSamlSlo non-programming error that can occur during SAML SLO flow
	ErrorSubTypeCodeSamlSlo
	// ErrorSubTypeCodeSamlInternal programming error, these will be displayed on an error page
	// so that we can fix the error on our end.
	ErrorSubTypeCodeSamlInternal
)

// Error codes under ErrorSubTypeCodeSamlSso (sub-type code + 1, + 2, + 3).
const (
	_ = ErrorSubTypeCodeSamlSso + iota
	ErrorCodeSamlSsoRequester
	ErrorCodeSamlSsoResponder
	ErrorCodeSamlSsoRequestVersionMismatch
)

// Error codes under ErrorSubTypeCodeSamlSlo (sub-type code + 1, + 2).
const (
	_ = ErrorSubTypeCodeSamlSlo + iota
	ErrorCodeSamlSloRequester
	ErrorCodeSamlSloResponder
)

// Error codes under ErrorSubTypeCodeSamlInternal (sub-type code + 1).
const (
	_ = ErrorSubTypeCodeSamlInternal + iota
	ErrorCodeSamlInternalGeneral
)

var (
	// Sub-type errors; SamlErrorHandler.HandleError dispatches on them with errors.Is.
	ErrorSubTypeSamlSso      = security.NewErrorSubType(ErrorSubTypeCodeSamlSso, errors.New("error sub-type: sso"))
	ErrorSubTypeSamlSlo      = security.NewErrorSubType(ErrorSubTypeCodeSamlSlo, errors.New("error sub-type: slo"))
	ErrorSubTypeSamlInternal = security.NewErrorSubType(ErrorSubTypeCodeSamlInternal, errors.New("error sub-type: internal"))
	// ErrorSamlSloRequester requester errors are displayed as a HTML page
	ErrorSamlSloRequester = security.NewCodedError(ErrorCodeSamlSloRequester, "SLO requester error")
	// ErrorSamlSloResponder responder errors are communicated back to SP via bindings
	ErrorSamlSloResponder = security.NewCodedError(ErrorCodeSamlSloResponder, "SLO responder error")
)
// SamlSsoErrorTranslator is implemented by errors that can be translated into a SAML status code,
// a status message and an HTTP status code (SamlError implements it).
type SamlSsoErrorTranslator interface {
	error
	TranslateErrorCode() string
	TranslateErrorMessage() string
	TranslateHttpStatusCode() int
}

// SamlError is a security.CodedError that also carries the SAML status code (EC) and the
// HTTP status code (SC) to use when the error is reported.
type SamlError struct {
	security.CodedError
	EC string // saml error code
	SC int    // status code
}
// NewSamlError creates a SamlError. The coded error part is built from code, e and causes;
// samlErrorCode is the SAML status code and httpStatusCode the HTTP status to report.
func NewSamlError(code int64, e interface{}, samlErrorCode string, httpStatusCode int, causes ...interface{}) *SamlError {
	samlErr := &SamlError{
		EC: samlErrorCode,
		SC: httpStatusCode,
	}
	samlErr.CodedError = *security.NewCodedError(code, e, causes...)
	return samlErr
}

// TranslateErrorCode returns the SAML status code.
func (se *SamlError) TranslateErrorCode() string {
	return se.EC
}

// TranslateErrorMessage returns the error message.
func (se *SamlError) TranslateErrorMessage() string {
	return se.Error()
}

// TranslateHttpStatusCode returns the HTTP status code.
func (se *SamlError) TranslateHttpStatusCode() int {
	return se.SC
}
// NewSamlInternalError creates an internal (programming) SAML error: no SAML status code, HTTP 500.
func NewSamlInternalError(text string, causes ...interface{}) error {
	msg := errors.New(text)
	return NewSamlError(ErrorCodeSamlInternalGeneral, msg, "", http.StatusInternalServerError, causes...)
}

// NewSamlRequesterError creates an SSO requester error: saml.StatusRequester, HTTP 400.
func NewSamlRequesterError(text string, causes ...interface{}) error {
	msg := errors.New(text)
	return NewSamlError(ErrorCodeSamlSsoRequester, msg, saml.StatusRequester, http.StatusBadRequest, causes...)
}

// NewSamlResponderError creates an SSO responder error: saml.StatusResponder, HTTP 500.
func NewSamlResponderError(text string, causes ...interface{}) error {
	msg := errors.New(text)
	return NewSamlError(ErrorCodeSamlSsoResponder, msg, saml.StatusResponder, http.StatusInternalServerError, causes...)
}

// NewSamlRequestVersionMismatch creates an SSO version mismatch error: saml.StatusVersionMismatch, HTTP 409.
func NewSamlRequestVersionMismatch(text string, causes ...interface{}) error {
	msg := errors.New(text)
	return NewSamlError(ErrorCodeSamlSsoRequestVersionMismatch, msg, saml.StatusVersionMismatch, http.StatusConflict, causes...)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/web"
"net/url"
)
var (
	// FeatureId identifies the SAML IdP authorize (SSO) feature.
	FeatureId = security.FeatureId("SamlAuthorizeEndpoint", security.FeatureOrderSamlAuthorizeEndpoint)
	// SloFeatureId identifies the SAML single-logout feature.
	SloFeatureId = security.FeatureId("SamlSLOEndpoint", security.FeatureOrderSamlLogout)
)
// Feature is the SAML IdP security.Feature. The same type backs both the authorize (SSO) feature
// (FeatureId, created by New) and the single-logout feature (SloFeatureId, created by NewLogout).
type Feature struct {
	id            security.FeatureIdentifier
	ssoCondition  web.RequestMatcher
	ssoLocation   *url.URL
	signingMethod string
	metadataPath  string
	issuer        security.Issuer
	logoutUrl     string
}

// New Standard security.Feature entrypoint for authorization, DSL style. Used with security.WebSecurity
func New() *Feature {
	feature := new(Feature)
	feature.id = FeatureId
	return feature
}

// NewLogout Standard security.Feature entrypoint for single-logout, DSL style. Used with security.WebSecurity
func NewLogout() *Feature {
	feature := new(Feature)
	feature.id = SloFeatureId
	return feature
}

// Identifier returns FeatureId or SloFeatureId, depending on the constructor used.
func (f *Feature) Identifier() security.FeatureIdentifier {
	return f.id
}

// SsoCondition sets the request matcher for SSO requests and returns f for chaining.
func (f *Feature) SsoCondition(condition web.RequestMatcher) *Feature {
	f.ssoCondition = condition
	return f
}

// SsoLocation sets the SSO endpoint location and returns f for chaining.
func (f *Feature) SsoLocation(location *url.URL) *Feature {
	f.ssoLocation = location
	return f
}

// MetadataPath sets the metadata endpoint path and returns f for chaining.
func (f *Feature) MetadataPath(path string) *Feature {
	f.metadataPath = path
	return f
}

// Issuer sets the issuer and returns f for chaining.
func (f *Feature) Issuer(issuer security.Issuer) *Feature {
	f.issuer = issuer
	return f
}

// SigningMethod sets the XML signature method and returns f for chaining.
func (f *Feature) SigningMethod(signatureMethod string) *Feature {
	f.signingMethod = signatureMethod
	return f
}

// EnableSLO when logoutUrl is set, SLO Request handling is added to logout.Feature.
// SLO feature cannot work properly if this value mismatches the logout URL
func (f *Feature) EnableSLO(logoutUrl string) *Feature {
	f.logoutUrl = logoutUrl
	return f
}
// Configure enables the SAML authorize feature on ws.
// It panics when ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *Feature {
	feature := New()
	modifier, ok := ws.(security.FeatureModifier)
	if !ok {
		panic(fmt.Errorf("unable to configure saml authserver: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
	}
	return modifier.Enable(feature).(*Feature)
}

// ConfigureLogout enables the SAML single-logout feature on ws.
// It panics when ws does not implement security.FeatureModifier.
func ConfigureLogout(ws security.WebSecurity) *Feature {
	feature := NewLogout()
	modifier, ok := ws.(security.FeatureModifier)
	if !ok {
		panic(fmt.Errorf("unable to configure saml authserver: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
	}
	return modifier.Enable(feature).(*Feature)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"os"
)
// loggerAdapter adapts a go-lanai log.Logger to the standard-library-style logging methods
// (Print*, Fatal*, Panic*) used as saml.IdentityProvider.Logger.
// Print* messages go to the delegate at Info level; Fatal* and Panic* messages at Error level.
// The *ln variants format their operands like fmt.Sprintln (always space-separated), as the
// standard log package does.
type loggerAdapter struct {
	delegate log.Logger
}

// newLoggerAdaptor wraps l in a loggerAdapter.
func newLoggerAdaptor(l log.Logger) *loggerAdapter {
	return &loggerAdapter{
		delegate: l,
	}
}

// Printf logs a formatted message at Info level.
func (s *loggerAdapter) Printf(format string, v ...interface{}) {
	s.delegate.Infof(format, v...)
}

// Print logs the operands formatted as by fmt.Sprint at Info level.
func (s *loggerAdapter) Print(v ...interface{}) {
	s.delegate.Info(fmt.Sprint(v...))
}

// Println logs the operands formatted as by fmt.Sprintln at Info level.
func (s *loggerAdapter) Println(v ...interface{}) {
	s.delegate.Info(s.sprintln(v...))
}

// Fatal logs the operands formatted as by fmt.Sprint at Error level, then exits with status 1.
func (s *loggerAdapter) Fatal(v ...interface{}) {
	s.delegate.Error(fmt.Sprint(v...))
	os.Exit(1)
}

// Fatalf logs a formatted message at Error level, then exits with status 1.
func (s *loggerAdapter) Fatalf(format string, v ...interface{}) {
	s.delegate.Errorf(format, v...)
	os.Exit(1)
}

// Fatalln logs the operands formatted as by fmt.Sprintln at Error level, then exits with status 1.
func (s *loggerAdapter) Fatalln(v ...interface{}) {
	s.delegate.Error(s.sprintln(v...))
	os.Exit(1)
}

// Panic logs the operands formatted as by fmt.Sprint at Error level, then panics with that message.
func (s *loggerAdapter) Panic(v ...interface{}) {
	msg := fmt.Sprint(v...)
	s.delegate.Error(msg)
	panic(msg)
}

// Panicf logs a formatted message at Error level, then panics with the formatted message.
func (s *loggerAdapter) Panicf(format string, v ...interface{}) {
	s.delegate.Errorf(format, v...)
	panic(fmt.Sprintf(format, v...))
}

// Panicln logs the operands formatted as by fmt.Sprintln at Error level, then panics with that message.
func (s *loggerAdapter) Panicln(v ...interface{}) {
	msg := s.sprintln(v...)
	s.delegate.Error(msg)
	panic(msg)
}

// sprintln formats v like fmt.Sprintln (operands always separated by spaces) but without the
// trailing newline, since the delegate terminates log lines itself.
func (s *loggerAdapter) sprintln(v ...interface{}) string {
	line := fmt.Sprintln(v...)
	// fmt.Sprintln always appends exactly one "\n", so this never underflows
	return line[:len(line)-1]
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"context"
"crypto/x509"
"errors"
"fmt"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
samlutils "github.com/cisco-open/go-lanai/pkg/security/saml/utils"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
"github.com/crewjam/saml"
"net/http"
"reflect"
"sync"
)
// SpMetadataManager caches resolved service provider (SP) metadata, keyed by SP entity ID.
// cache and processed are read under cacheMutex's read lock and modified under its write lock.
type SpMetadataManager struct {
//for fetching sp metadata (passed to samlutils.ResolveMetadata)
httpClient *http.Client
//entityId to descriptor
cache map[string]*saml.EntityDescriptor
// entityId to the client details the cached descriptor was resolved with;
// compared with the current client details to detect changed clients
processed map[string]SamlSpDetails
cacheMutex sync.RWMutex
}
// GetServiceProvider returns the client details and the resolved metadata cached for the SP whose
// entity ID is serviceProviderID. An error is returned when no metadata for that SP is cached
// (e.g. it was never resolved by RefreshCache, or its resolution failed).
func (m *SpMetadataManager) GetServiceProvider(serviceProviderID string) (SamlSpDetails, *saml.EntityDescriptor, error) {
	m.cacheMutex.RLock()
	defer m.cacheMutex.RUnlock()
	// direct map lookup rather than scanning every cached entry
	if descriptor, ok := m.cache[serviceProviderID]; ok {
		return m.processed[serviceProviderID], descriptor, nil
	}
	return SamlSpDetails{}, nil, errors.New(fmt.Sprintf("service provider metadata for %s not found", serviceProviderID))
}
// RefreshCache synchronizes the cache with clients: metadata of new or changed clients is
// (re-)resolved, and cached entries for clients that are changed or gone are evicted.
//
// The diff is first computed under the read lock, so the common "nothing changed" case never
// takes the write lock. Another goroutine may update the cache between releasing the read lock
// and acquiring the write lock, so the diff is recomputed under the write lock before acting on it.
// Entries whose metadata fails to resolve are left out of the cache.
func (m *SpMetadataManager) RefreshCache(ctx context.Context, clients []samlctx.SamlClient) {
	m.cacheMutex.RLock()
	toRemove, toRefresh := m.compareWithCache(clients)
	m.cacheMutex.RUnlock()
	if len(toRemove) == 0 && len(toRefresh) == 0 {
		return
	}

	m.cacheMutex.Lock()
	defer m.cacheMutex.Unlock()
	if toRemove, toRefresh = m.compareWithCache(clients); len(toRemove) == 0 && len(toRefresh) == 0 {
		// someone else already brought the cache up to date
		return
	}

	resolved := m.resolveMetadata(ctx, toRefresh)
	for id, shouldRemove := range toRemove {
		if !shouldRemove {
			continue
		}
		delete(m.cache, id)
		delete(m.processed, id)
	}
	for i := range toRefresh {
		id := toRefresh[i].EntityId
		descriptor, ok := resolved[id]
		if !ok {
			continue
		}
		m.cache[id] = descriptor
		m.processed[id] = toRefresh[i]
	}
}
// compareWithCache diffs clients against the current cache. It does no locking itself;
// RefreshCache calls it while holding cacheMutex.
//
// refresh holds the details of every client that is either not cached yet or whose details differ
// from the ones its cached metadata was resolved with. remove marks every cached entity ID that does
// not belong to an unchanged client, which covers both entries about to be refreshed and entries
// whose client no longer exists.
func (m *SpMetadataManager) compareWithCache(clients []samlctx.SamlClient) (remove map[string]bool, refresh []SamlSpDetails) {
	unchanged := make(map[string]bool)
	remove = make(map[string]bool)
	for _, client := range clients {
		var details SamlSpDetails
		if dc, isDefault := client.(DefaultSamlClient); isDefault {
			details = dc.SamlSpDetails
		} else {
			details = SamlSpDetails{
				EntityId:                             client.GetEntityId(),
				MetadataSource:                       client.GetMetadataSource(),
				SkipAssertionEncryption:              client.ShouldSkipAssertionEncryption(),
				SkipAuthRequestSignatureVerification: client.ShouldSkipAuthRequestSignatureVerification(),
				MetadataRequireSignature:             client.ShouldMetadataRequireSignature(),
				MetadataTrustCheck:                   client.ShouldMetadataTrustCheck(),
				MetadataTrustedKeys:                  client.GetMetadataTrustedKeys(),
			}
		}

		_, cached := m.cache[details.EntityId]
		if cached && reflect.DeepEqual(m.processed[details.EntityId], details) {
			unchanged[details.EntityId] = true
			continue
		}
		refresh = append(refresh, details)
	}

	for entityId := range m.cache {
		if !unchanged[entityId] {
			remove[entityId] = true
		}
	}
	return remove, refresh
}
// resolveMetadata fetches and validates the metadata of every SP in refresh and returns the accepted
// descriptors keyed by entity ID. An SP is left out of the result, with a warning logged, when:
//   - its metadata cannot be resolved;
//   - a signature is required but the metadata is unsigned;
//   - the trust check is enabled and the signature cannot be verified against the configured
//     trusted keys (keys that fail to load are skipped with a warning).
//
// Note: RefreshCache calls this while holding the cache write lock, so readers of the cache wait
// for all metadata fetches to finish.
func (m *SpMetadataManager) resolveMetadata(ctx context.Context, refresh []SamlSpDetails) (resolved map[string]*saml.EntityDescriptor) {
	resolved = make(map[string]*saml.EntityDescriptor)
	for _, details := range refresh {
		spDescriptor, data, err := samlutils.ResolveMetadata(ctx, details.MetadataSource, samlutils.WithHttpClient(m.httpClient))
		if err != nil {
			logger.WithContext(ctx).Warnf("could not resolve sp metadata [%s]: %v", details.EntityId, err)
			continue
		}
		if details.MetadataRequireSignature && spDescriptor.Signature == nil {
			logger.WithContext(ctx).Warnf("sp metadata [%s] rejected because it is not signed", details.EntityId)
			continue
		}
		if details.MetadataTrustCheck {
			var allCerts []*x509.Certificate
			for _, keyLoc := range details.MetadataTrustedKeys {
				certs, e := cryptoutils.LoadCert(keyLoc)
				if e != nil {
					// a missing key only weakens the trust set; verification below decides acceptance
					logger.WithContext(ctx).Warnf("unable to load trusted key [%v] for sp metadata [%s]: %v", keyLoc, details.EntityId, e)
					continue
				}
				allCerts = append(allCerts, certs...)
			}
			if e := samlutils.VerifySignature(samlutils.MetadataSignature(data, allCerts...)); e != nil {
				logger.WithContext(ctx).Warnf("sp metadata [%s] rejected because its signature cannot be verified: %v", details.EntityId, e)
				continue
			}
		}
		resolved[details.EntityId] = spDescriptor
	}
	return resolved
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"crypto"
"crypto/x509"
"encoding/xml"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/crewjam/saml"
"github.com/gin-gonic/gin"
"net/http"
"net/url"
"sort"
)
// Options configures the saml.IdentityProvider created by NewMetadataMiddleware.
type Options struct {
// Key and Cert are the IdP's key pair, used to sign IdP messages (e.g. LogoutResponses).
Key crypto.PrivateKey
Cert *x509.Certificate
// EntityIdUrl becomes IdentityProvider.MetadataURL, which here only defines the entity ID.
EntityIdUrl url.URL
// SsoUrl and SloUrl become IdentityProvider.SSOURL and IdentityProvider.LogoutURL.
SsoUrl url.URL
SloUrl url.URL
// SigningMethod becomes IdentityProvider.SignatureMethod.
SigningMethod string
// NOTE(review): not read by NewMetadataMiddleware, which takes the client store as a separate argument.
serviceProviderManager samlctx.SamlClientStore
}
// MetadataMiddleware serves this IdP's metadata and keeps the cache of SP metadata up to date.
type MetadataMiddleware struct {
samlClientStore samlctx.SamlClientStore // used to load the saml clients
spMetadataManager *SpMetadataManager // manages the resolved service provider metadata
// idp describes this identity provider; the same instance is shared by every request the middleware handles.
idp *saml.IdentityProvider
}
// NewMetadataMiddleware builds a MetadataMiddleware that loads SAML clients from samlClientStore.
// Its saml.IdentityProvider is populated from opts. Resolved SP metadata is cached in a fresh
// SpMetadataManager that fetches remote metadata with http.DefaultClient.
//
// NOTE(review): http.DefaultClient has no timeout, and SpMetadataManager.RefreshCache resolves
// metadata while holding its write lock; a slow SP metadata endpoint therefore stalls lookups.
// Consider a dedicated client with a timeout.
func NewMetadataMiddleware(opts *Options, samlClientStore samlctx.SamlClientStore) *MetadataMiddleware {
	return &MetadataMiddleware{
		samlClientStore: samlClientStore,
		spMetadataManager: &SpMetadataManager{
			httpClient: http.DefaultClient,
			cache:      make(map[string]*saml.EntityDescriptor),
			processed:  make(map[string]SamlSpDetails),
		},
		idp: &saml.IdentityProvider{
			Key:         opts.Key,
			Certificate: opts.Cert,
			Logger:      newLoggerAdaptor(logger),
			// we serve metadata with our own handler, so MetadataURL only defines the entity ID
			MetadataURL:     opts.EntityIdUrl,
			SSOURL:          opts.SsoUrl,
			LogoutURL:       opts.SloUrl,
			SignatureMethod: opts.SigningMethod,
		},
	}
}
// RefreshMetadataHandler returns a gin handler that, for requests matching condition, loads all SAML
// clients from the client store and refreshes the cached SP metadata accordingly.
// The handler never aborts the request. Requests that do not match (or for which matching fails)
// are passed through untouched. If the client store fails, the refresh is skipped and a warning is logged.
func (mw *MetadataMiddleware) RefreshMetadataHandler(condition web.RequestMatcher) gin.HandlerFunc {
	return func(c *gin.Context) {
		if matches, err := condition.MatchesWithContext(c.Request.Context(), c.Request); !matches || err != nil {
			return
		}
		clients, e := mw.samlClientStore.GetAllSamlClient(c.Request.Context())
		if e != nil {
			// best-effort: keep serving with the current cache, but make the failure visible
			logger.WithContext(c).Warnf("unable to load SAML clients, SP metadata not refreshed: %v", e)
			return
		}
		mw.spMetadataManager.RefreshCache(c, clients)
	}
}
// MetadataHandlerFunc returns a gin handler that serves this IdP's metadata as an XML attachment
// (metadata.xml, Content-Type application/samlmetadata+xml).
//
// On top of what saml.IdentityProvider.Metadata produces, it:
//   - sorts the SSO services by binding for a stable output;
//   - always sets WantAuthnRequestsSigned, which the saml package does not support;
//   - when a logout URL is configured, advertises both Redirect and POST SLO bindings on it.
//
// If the metadata cannot be marshalled, the error is logged and 500 is returned.
func (mw *MetadataMiddleware) MetadataHandlerFunc() gin.HandlerFunc {
	return func(c *gin.Context) {
		metadata := mw.idp.Metadata()
		idpDesc := &metadata.IDPSSODescriptors[0]
		sort.SliceStable(idpDesc.SingleSignOnServices, func(i, j int) bool {
			return idpDesc.SingleSignOnServices[i].Binding < idpDesc.SingleSignOnServices[j].Binding
		})
		//We always want the authentication request to be signed
		//But because this is not supported by the saml package, we set it here explicitly
		wantSigned := true
		idpDesc.WantAuthnRequestsSigned = &wantSigned
		// We also support POST Binding of logout request, which is not added by crewjam/saml package
		if logoutURL := mw.idp.LogoutURL.String(); logoutURL != "" {
			idpDesc.SSODescriptor.SingleLogoutServices = []saml.Endpoint{
				{Binding: saml.HTTPRedirectBinding, Location: logoutURL},
				{Binding: saml.HTTPPostBinding, Location: logoutURL},
			}
		}

		buf, err := xml.MarshalIndent(metadata, "", " ")
		if err != nil {
			logger.WithContext(c).Errorf("unable to marshal IdP metadata: %v", err)
			c.AbortWithStatus(http.StatusInternalServerError)
			return
		}
		// send the response
		w := c.Writer
		w.Header().Set("Content-Type", "application/samlmetadata+xml")
		w.Header().Set("Content-Disposition", "attachment; filename=metadata.xml")
		// a failed write (e.g. client went away) cannot be reported to the client; deliberately ignored
		_, _ = w.Write(buf)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/cisco-open/go-lanai/pkg/web"
"go.uber.org/fx"
)
// Module is the bootstrap module of the SAML IdP endpoints. Its only fx option invokes register,
// which registers the authorize and single logout feature configurers.
var Module = &bootstrap.Module{
Name: "saml auth - authorize",
Precedence: security.MinSecurityPrecedence + 20,
Options: []fx.Option{
fx.Invoke(register),
},
}
// logger is this package's logger, named "SAML.SSO".
var logger = log.New("SAML.SSO")
// Use registers Module via bootstrap.Register.
func Use() {
bootstrap.Register(Module)
}
// initDI is the fx parameter object injected into register.
// Fields tagged optional may be nil/zero when no provider exists.
// NOTE(review): ServerProperties is not read by register.
type initDI struct {
fx.In
SecRegistrar security.Registrar `optional:"true"`
Properties samlctx.SamlProperties
ServerProperties web.ServerProperties
ServiceProviderManager samlctx.SamlClientStore `optional:"true"`
AccountStore security.AccountStore `optional:"true"`
AttributeGenerator AttributeGenerator `optional:"true"`
}
// register registers the SAML authorize feature (FeatureId) and the SAML single logout feature
// (SloFeatureId) with the security registrar. It does nothing when no registrar is available, and
// panics if the registrar is not a security.FeatureRegistrar.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	authConfigurer := newSamlAuthorizeEndpointConfigurer(di.Properties, di.ServiceProviderManager, di.AccountStore, di.AttributeGenerator)
	registrar := di.SecRegistrar.(security.FeatureRegistrar)
	registrar.RegisterFeature(FeatureId, authConfigurer)

	sloConfigurer := newSamlLogoutEndpointConfigurer(di.Properties, di.ServiceProviderManager)
	registrar.RegisterFeature(SloFeatureId, sloConfigurer)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/logout"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
"net/http"
)
// SamlLogoutEndpointConfigurer configures SAML single logout (SLO) handling for the Feature.
// It embeds samlConfigurer (properties and SAML client store); metadataMiddleware used by Apply is
// presumably promoted from it — confirm.
type SamlLogoutEndpointConfigurer struct {
samlConfigurer
}
// newSamlLogoutEndpointConfigurer creates a SamlLogoutEndpointConfigurer backed by the given
// SAML properties and client store.
func newSamlLogoutEndpointConfigurer(properties samlctx.SamlProperties,
	samlClientStore samlctx.SamlClientStore) *SamlLogoutEndpointConfigurer {
	base := samlConfigurer{
		properties:      properties,
		samlClientStore: samlClientStore,
	}
	return &SamlLogoutEndpointConfigurer{samlConfigurer: base}
}
// Apply wires SAML single logout into ws when the feature has a logout URL (see EnableSLO);
// without one it does nothing. When enabled it:
//   - adds a middleware on the logout URL (GET and POST) that refreshes SP metadata for requests
//     carrying a SAMLRequest form value, and
//   - registers a SamlSingleLogoutMiddleware as logout handler, success handler, error handler and
//     entry point of the logout feature.
func (c *SamlLogoutEndpointConfigurer) Apply(feature security.Feature, ws security.WebSecurity) (err error) {
	f := feature.(*Feature)
	if f.logoutUrl == "" {
		// SLO not enabled
		return nil
	}

	sloMw := NewSamlSingleLogoutMiddleware(c.metadataMiddleware(f))
	refresh := middleware.NewBuilder("Saml Service Provider Refresh").
		ApplyTo(matcher.RouteWithPattern(f.logoutUrl, http.MethodGet, http.MethodPost)).
		Order(security.MWOrderSAMLMetadataRefresh).
		Use(sloMw.RefreshMetadataHandler(sloMw.SLOCondition()))
	ws.Add(refresh)

	logout.Configure(ws).
		AddLogoutHandler(sloMw).
		AddSuccessHandler(sloMw).
		AddErrorHandler(sloMw).
		AddEntryPoint(sloMw)
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
samlutils "github.com/cisco-open/go-lanai/pkg/security/saml/utils"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/crewjam/saml"
"github.com/gin-gonic/gin"
"net/http"
)
var (
// ctxKeySloRequest is the key under which the in-flight *SamlLogoutRequest is stored in the gin
// context (populateContext) and read back via ctx.Value (HandleLogout, HandleAuthenticationSuccess).
ctxKeySloRequest = "slo.request"
// supportedLogoutResponseBindings are the bindings this IdP can use to send a LogoutResponse back
// to the SP; only HTTP-POST (see determineSloEndpoint and SamlLogoutRequest.WriteResponse).
supportedLogoutResponseBindings = utils.NewStringSet(saml.HTTPPostBinding)
)
// SamlSingleLogoutMiddleware is a
// 1. logout.LogoutHandler
// 2. logout.ConditionalLogoutHandler
// 3. security.AuthenticationSuccessHandler
// 4. security.AuthenticationErrorHandler
// 5. security.AuthenticationEntryPoint
// focused on validating SAML logout requests and sending back a SAML LogoutResponse.
// It embeds *MetadataMiddleware for the IdP and the SP metadata cache, and SamlErrorHandler,
// which presumably provides the HandleError method used by the error handling methods — confirm.
type SamlSingleLogoutMiddleware struct {
*MetadataMiddleware
SamlErrorHandler
}
// NewSamlSingleLogoutMiddleware creates a SamlSingleLogoutMiddleware on top of metaMw.
// The embedded SamlErrorHandler is left at its zero value.
func NewSamlSingleLogoutMiddleware(metaMw *MetadataMiddleware) *SamlSingleLogoutMiddleware {
	mw := new(SamlSingleLogoutMiddleware)
	mw.MetadataMiddleware = metaMw
	return mw
}

// Order returns order.Highest so that this handler is always performed first.
func (mw *SamlSingleLogoutMiddleware) Order() int {
	return order.Highest
}

// SLOCondition matches requests carrying a SAMLRequest form value, i.e. candidate SP-initiated
// single logout requests.
func (mw *SamlSingleLogoutMiddleware) SLOCondition() web.RequestMatcher {
	return matcher.RequestHasForm(samlutils.HttpParamSAMLRequest)
}
// ShouldLogout is a logout.ConditionalLogoutHandler method that intercepts SP-initiated SAML logout
// requests. Possible outcomes:
//   - nil if the logout is not a SAML single logout (no SAMLRequest found);
//   - nil if the request is a valid SAML LogoutRequest;
//   - an ErrorSubTypeSamlSlo error if a SAML LogoutRequest is found but invalid.
//
// In the last two cases the request state is stored in the gin context (see populateContext).
func (mw *SamlSingleLogoutMiddleware) ShouldLogout(ctx context.Context, r *http.Request, _ http.ResponseWriter, _ security.Authentication) error {
	gc := web.GinContext(ctx)
	sloReq := mw.newSamlLogoutRequest(r)
	var logoutReq saml.LogoutRequest
	parsed := samlutils.ParseSAMLObject(gc, &logoutReq)
	if parsed.Err != nil {
		if len(parsed.Encoded) == 0 {
			// no SAMLRequest at all: not a SAML single logout, nothing to intercept
			return nil
		}
		// a SAMLRequest is present but invalid: keep the state and cancel the logout with an error
		mw.populateContext(gc, sloReq)
		return ErrorSamlSloRequester.WithMessage("unable to parse SAML SamlLogoutRequest: %v", parsed.Err)
	}

	sloReq.Binding = parsed.Binding
	sloReq.Request = &logoutReq
	sloReq.RequestBuffer = parsed.Decoded
	return mw.preProcessLogoutRequest(gc, sloReq)
}
// HandleLogout is a logout.LogoutHandler method. When a SAML logout request was stored in the
// context by ShouldLogout, it checks the request against the current authentication and prepares
// the success LogoutResponse. Otherwise it does nothing.
func (mw *SamlSingleLogoutMiddleware) HandleLogout(ctx context.Context, _ *http.Request, _ http.ResponseWriter, auth security.Authentication) error {
	req, isSlo := ctx.Value(ctxKeySloRequest).(*SamlLogoutRequest)
	if !isSlo {
		return nil
	}
	err := mw.processLogoutRequest(ctx, req, auth)
	if err == nil {
		err = mw.prepareSuccessSamlResponse(ctx, req)
	}
	return err
}
// HandleAuthenticationSuccess sends the prepared SAML LogoutResponse when the logout was a SAML
// single logout. A failure to write it is reported through HandleError as a SAML internal error.
// Non-SAML logouts are left untouched.
func (mw *SamlSingleLogoutMiddleware) HandleAuthenticationSuccess(ctx context.Context, r *http.Request, rw http.ResponseWriter, from, to security.Authentication) {
	req, isSlo := ctx.Value(ctxKeySloRequest).(*SamlLogoutRequest)
	if !isSlo {
		return
	}
	// on success the SAML response is expected to be prepared already (HandleLogout); just send it
	if err := req.WriteResponse(rw); err != nil {
		mw.HandleError(ctx, r, rw, NewSamlInternalError(fmt.Sprintf("unable to send logout success response: %v", err), err))
	}
}
// HandleAuthenticationError (security.AuthenticationErrorHandler) delegates to HandleError.
func (mw *SamlSingleLogoutMiddleware) HandleAuthenticationError(ctx context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	mw.HandleError(ctx, r, rw, err)
}

// Commence (security.AuthenticationEntryPoint) delegates to HandleError as well.
func (mw *SamlSingleLogoutMiddleware) Commence(ctx context.Context, r *http.Request, rw http.ResponseWriter, err error) {
	mw.HandleError(ctx, r, rw, err)
}

// newSamlLogoutRequest starts a SamlLogoutRequest for r, bound to this middleware's IdP.
func (mw *SamlSingleLogoutMiddleware) newSamlLogoutRequest(r *http.Request) *SamlLogoutRequest {
	req := new(SamlLogoutRequest)
	req.HTTPRequest = r
	req.IDP = mw.idp
	return req
}
// preProcessLogoutRequest completes req from the cached SP metadata and validates it.
// In order, it:
//   - requires an Issuer and looks up the issuing SP's metadata;
//   - requires exactly one SP SSO descriptor;
//   - determines the SLO callback endpoint;
//   - captures RelayState;
//   - verifies the signature (unless skipped for the SP) and validates the request.
//
// Whatever the outcome, req (as populated so far) is stored in the gin context on return.
func (mw *SamlSingleLogoutMiddleware) preProcessLogoutRequest(gc *gin.Context, req *SamlLogoutRequest) error {
	defer mw.populateContext(gc, req)

	// Note: we return Requester errors until we can determine the callback binding
	issuer := req.Request.Issuer
	if issuer == nil || issuer.Value == "" {
		return ErrorSamlSloRequester.WithMessage("logout request missing Issuer")
	}
	spId := issuer.Value
	spDetails, spMeta, lookupErr := mw.spMetadataManager.GetServiceProvider(spId)
	if lookupErr != nil {
		return ErrorSamlSloRequester.WithMessage("cannot find service provider metadata [%s]", spId)
	}

	req.SPMeta = spMeta
	if len(spMeta.SPSSODescriptors) != 1 {
		return ErrorSamlSloRequester.WithMessage("expected exactly one SP SSO descriptor in SP metadata [%s]", spId)
	}
	spSSODesc := spMeta.SPSSODescriptors[0]
	req.SPSSODescriptor = &spSSODesc
	if endpointErr := mw.determineSloEndpoint(gc, req); endpointErr != nil {
		return endpointErr
	}

	req.RelayState = req.HTTPRequest.FormValue(samlutils.HttpParamRelayState)
	return mw.validateLogoutRequest(gc, req, &spDetails)
}
// determineSloEndpoint selects the first SP SingleLogoutService with a supported binding and a
// non-empty Location as req.Callback. ResponseLocation defaults to Location when empty.
// The chosen endpoint is a copy, so the cached SP metadata is never modified.
// Note: only the POST binding is supported for now, because of a crewjam/saml 0.4.8 limitation.
func (mw *SamlSingleLogoutMiddleware) determineSloEndpoint(_ *gin.Context, req *SamlLogoutRequest) error {
	var callback *saml.Endpoint
	for _, ep := range req.SPSSODescriptor.SingleLogoutServices {
		if supportedLogoutResponseBindings.Has(ep.Binding) && ep.Location != "" {
			chosen := ep
			callback = &chosen
			break
		}
	}
	if callback == nil {
		return ErrorSamlSloRequester.WithMessage("SAML SLO unable to find supported response bindings from SP. Should be one of %v", supportedLogoutResponseBindings.Values())
	}
	if callback.ResponseLocation == "" {
		callback.ResponseLocation = callback.Location
	}
	req.Callback = callback
	return nil
}
// validateLogoutRequest verifies the request signature, unless the SP is configured to skip
// auth-request signature verification, and then runs req.Validate.
// A signature failure is reported as an ErrorSamlSloResponder error.
func (mw *SamlSingleLogoutMiddleware) validateLogoutRequest(_ *gin.Context, req *SamlLogoutRequest, spDetails *SamlSpDetails) error {
	if spDetails.SkipAuthRequestSignatureVerification {
		return req.Validate()
	}
	if sigErr := req.VerifySignature(); sigErr != nil {
		return ErrorSamlSloResponder.WithMessage("%s", sigErr.Error())
	}
	return req.Validate()
}
// processLogoutRequest checks that a SAML LogoutRequest targets the current session.
//
// When no principal is known (auth is nil or its state is below security.StatePrincipalKnown)
// there is nothing to match, and nil is returned. Otherwise:
//   - NameIDs in the email, transient or persistent formats are rejected as unsupported;
//   - the NameID value of any other format must equal the current username.
//
// Rejections are ErrorSamlSloResponder errors.
func (mw *SamlSingleLogoutMiddleware) processLogoutRequest(ctx context.Context, req *SamlLogoutRequest, auth security.Authentication) error {
	if auth == nil || auth.State() < security.StatePrincipalKnown {
		// no additional check are needed
		return nil
	}
	// A request is stored in the context even when parsing/validation failed (ShouldLogout and the
	// deferred populateContext), so never assume Request/NameID are present here.
	if req.Request == nil || req.Request.NameID == nil {
		return ErrorSamlSloResponder.WithMessage("request missing saml:NameID")
	}

	nameID := req.Request.NameID
	switch saml.NameIDFormat(nameID.Format) {
	case saml.EmailAddressNameIDFormat, saml.TransientNameIDFormat, saml.PersistentNameIDFormat:
		return ErrorSamlSloResponder.WithMessage("unsupported NameID format [%s]", nameID.Format)
	}

	// any other format: we assume the NameID value is the username
	username, e := security.GetUsername(auth)
	switch {
	case e != nil:
		logger.WithContext(ctx).Warnf("SAML SLO rejected: unable to determine username of current session. Caused by: %v", e)
		return ErrorSamlSloResponder.WithMessage("NameID doesn't match current session")
	case username != nameID.Value:
		logger.WithContext(ctx).Warnf("SAML SLO rejected: NameID doesn't match current session")
		return ErrorSamlSloResponder.WithMessage("NameID doesn't match current session")
	}
	return nil
}
// populateContext stores req in the gin context under ctxKeySloRequest, where HandleLogout and
// HandleAuthenticationSuccess look it up.
func (mw *SamlSingleLogoutMiddleware) populateContext(gc *gin.Context, req *SamlLogoutRequest) {
	gc.Set(ctxKeySloRequest, req)
}

// prepareSuccessSamlResponse builds and signs a Success LogoutResponse for req and stores it in
// req.Response. If the response cannot be made, the cause is logged and an authentication warning
// error is returned instead.
func (mw *SamlSingleLogoutMiddleware) prepareSuccessSamlResponse(ctx context.Context, req *SamlLogoutRequest) error {
	resp, e := MakeLogoutResponse(req, saml.StatusSuccess, "")
	if e != nil {
		logger.WithContext(ctx).Warnf("SAML SLO unable to sign logout response: %v", e)
		return security.NewAuthenticationWarningError("Unable to send SAML Logout Response")
	}
	req.Response = resp
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"crypto/x509"
"encoding/base64"
samlutils "github.com/cisco-open/go-lanai/pkg/security/saml/utils"
"github.com/crewjam/saml"
"net/http"
"regexp"
"time"
)
// SamlLogoutRequest holds the state of one SP-initiated SAML single logout exchange, from the
// parsed LogoutRequest through to the LogoutResponse sent back to the SP.
type SamlLogoutRequest struct {
// HTTPRequest is the incoming HTTP request carrying the SAMLRequest.
HTTPRequest *http.Request
// Binding is the SAML binding the request was received with.
Binding string
// Request is the parsed LogoutRequest; RequestBuffer is its decoded XML, used for signature verification.
Request *saml.LogoutRequest
RequestBuffer []byte
// RelayState is the RelayState form value, sent back along with the response.
RelayState string
// IDP is the identity provider handling the request; shared by all requests of the middleware.
IDP *saml.IdentityProvider
SPMeta *saml.EntityDescriptor // the requester
// SPSSODescriptor is a copy of the single SP SSO descriptor in SPMeta.
SPSSODescriptor *saml.SPSSODescriptor
// Callback is the SP endpoint the LogoutResponse is sent to.
Callback *saml.Endpoint
// Response is the signed LogoutResponse, once prepared.
Response *saml.LogoutResponse
}
// Validate checks the parsed LogoutRequest against this IdP:
//   - Destination, when present, must equal the IdP's logout URL;
//   - the request must have been issued within saml.MaxIssueDelay;
//   - Version must be "2.0";
//   - a non-empty saml:NameID must be present.
//
// Violations are ErrorSamlSloResponder errors, except the version mismatch, which uses
// NewSamlRequestVersionMismatch.
func (r SamlLogoutRequest) Validate() error {
	now := time.Now()
	logoutURL := r.IDP.LogoutURL.String()
	if dest := r.Request.Destination; dest != "" && dest != logoutURL {
		return ErrorSamlSloResponder.WithMessage("expected destination to be %q, not %q", logoutURL, dest)
	}
	if expiry := r.Request.IssueInstant.Add(saml.MaxIssueDelay); expiry.Before(now) {
		return ErrorSamlSloResponder.WithMessage("request expired at %s", expiry)
	}
	if r.Request.Version != "2.0" {
		return NewSamlRequestVersionMismatch("expected saml version 2.0")
	}
	if nameID := r.Request.NameID; nameID == nil || nameID.Value == "" {
		return ErrorSamlSloResponder.WithMessage("request missing saml:NameID")
	}
	return nil
}
// VerifySignature verifies the LogoutRequest signature (per its binding) against the SP's signing
// certificate from metadata. It fails with an ErrorSamlSloResponder error if no usable certificate
// is found.
func (r SamlLogoutRequest) VerifySignature() error {
	signingCert, certErr := r.serviceProviderCert("signing")
	if certErr != nil {
		return ErrorSamlSloResponder.WithMessage("logout request signature cannot be verified, because metadata does not include certificate")
	}
	trusted := []*x509.Certificate{signingCert}
	return samlutils.VerifySignature(func(sc *samlutils.SignatureContext) {
		sc.Binding = r.Binding
		sc.Request = r.HTTPRequest
		sc.XMLData = r.RequestBuffer
		sc.Certs = trusted
	})
}
// WriteResponse sends the prepared LogoutResponse to the SP through the callback's binding.
// Only HTTP-POST is supported; it is rendered as an auto-submitting HTML form carrying RelayState.
// It fails with ErrorSamlSloRequester errors if no response was prepared, the binding is
// unsupported, or writing fails.
func (r SamlLogoutRequest) WriteResponse(rw http.ResponseWriter) error {
	if r.Response == nil {
		return ErrorSamlSloRequester.WithMessage("logout response is not available")
	}
	binding := r.Callback.Binding
	if binding != saml.HTTPPostBinding {
		return ErrorSamlSloRequester.WithMessage("%s: unsupported binding %s", r.SPMeta.EntityID, binding)
	}
	// the only supported binding is the HTTP-POST binding, so don't need to apply Redirect fix
	if writeErr := samlutils.WritePostBindingHTML(r.Response.Post(r.RelayState), rw); writeErr != nil {
		return ErrorSamlSloRequester.WithMessage("unable to write response: %v", writeErr)
	}
	return nil
}
// logoutSpCertWhitespace matches whitespace runs inside base64 certificate data taken from SP metadata.
var logoutSpCertWhitespace = regexp.MustCompile(`\s+`)

// serviceProviderCert returns the first SP certificate whose KeyDescriptor use equals usage
// (e.g. "signing"). If none is labeled with that use, the first non-empty certificate of a
// KeyDescriptor without a use attribute is returned. Errors are SAML internal errors.
func (r SamlLogoutRequest) serviceProviderCert(usage string) (*x509.Certificate, error) {
	certStr := ""
	for _, kd := range r.SPSSODescriptor.KeyDescriptors {
		if kd.Use == usage && len(kd.KeyInfo.X509Data.X509Certificates) > 0 {
			certStr = kd.KeyInfo.X509Data.X509Certificates[0].Data
			break
		}
	}
	// If there are no certs explicitly labeled for the requested usage, return the first
	// non-empty cert of an unlabeled key descriptor.
	if certStr == "" {
		for _, kd := range r.SPSSODescriptor.KeyDescriptors {
			if kd.Use == "" &&
				len(kd.KeyInfo.X509Data.X509Certificates) > 0 &&
				kd.KeyInfo.X509Data.X509Certificates[0].Data != "" {
				certStr = kd.KeyInfo.X509Data.X509Certificates[0].Data
				break
			}
		}
	}
	if certStr == "" {
		return nil, NewSamlInternalError("certificate not found")
	}

	// strip whitespace, then decode the base64 DER certificate
	certStr = logoutSpCertWhitespace.ReplaceAllString(certStr, "")
	certBytes, err := base64.StdEncoding.DecodeString(certStr)
	if err != nil {
		// NewSamlInternalError does not format its text, so the cause is appended explicitly
		return nil, NewSamlInternalError("cannot decode certificate base64: "+err.Error(), err)
	}
	cert, err := x509.ParseCertificate(certBytes)
	if err != nil {
		return nil, NewSamlInternalError("cannot parse certificate: "+err.Error(), err)
	}
	return cert, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"crypto/tls"
"fmt"
"github.com/beevik/etree"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
"github.com/crewjam/saml"
dsig "github.com/russellhaering/goxmldsig"
)
// MakeLogoutResponse builds and signs a LogoutResponse answering req.Request.
// The response carries the given status code and, if non-empty, a status message. It is addressed
// to the callback's ResponseLocation and issued under the IdP's MetadataURL (its entity ID).
// The signed response is stored in req.Response and also returned.
// When the IdP has no signature method configured, RSA-SHA1 is used.
func MakeLogoutResponse(req *SamlLogoutRequest, code string, message string) (*saml.LogoutResponse, error) {
	now := saml.TimeNow()
	response := saml.LogoutResponse{
		Destination:  req.Callback.ResponseLocation,
		ID:           fmt.Sprintf("id-%x", cryptoutils.RandomBytes(20)),
		InResponseTo: req.Request.ID,
		IssueInstant: now,
		Version:      "2.0",
		Issuer: &saml.Issuer{
			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
			Value:  req.IDP.MetadataURL.String(),
		},
		Status: saml.Status{
			StatusCode: saml.StatusCode{
				Value: code,
			},
		},
	}
	if len(message) != 0 {
		response.Status.StatusMessage = &saml.StatusMessage{
			Value: message,
		}
	}

	// req.IDP is shared by all requests of the middleware: default the signature method on a
	// shallow copy instead of writing to the shared instance (avoids a data race).
	signer := *req.IDP
	if len(signer.SignatureMethod) == 0 {
		signer.SignatureMethod = dsig.RSASHA1SignatureMethod
	}
	if e := SignLogoutResponse(&signer, &response); e != nil {
		return nil, e
	}
	req.Response = &response
	return &response, nil
}
// SignLogoutResponse is similar to saml.ServiceProvider.SignLogoutResponse, but for IDP.
// It creates an enveloped XML signature over resp with the IdP's key and certificate, using
// idp.SignatureMethod (one of RSA-SHA1, RSA-SHA256 or RSA-SHA512) and exclusive C14N 1.0
// canonicalization, and stores the signature element in resp.Signature.
// The IdP's intermediate certificates, if any, are added to the key pair's chain, as MakeAssertionEl does.
// An error is returned if the IdP has no certificate, the signature method is unsupported, or signing fails.
func SignLogoutResponse(idp *saml.IdentityProvider, resp *saml.LogoutResponse) error {
	if idp.Certificate == nil {
		return fmt.Errorf("unable to sign logout response: IDP certificate is not configured")
	}
	switch idp.SignatureMethod {
	case dsig.RSASHA1SignatureMethod, dsig.RSASHA256SignatureMethod, dsig.RSASHA512SignatureMethod:
	default:
		return fmt.Errorf("invalid signing method %s", idp.SignatureMethod)
	}

	keyPair := tls.Certificate{
		Certificate: [][]byte{idp.Certificate.Raw},
		PrivateKey:  idp.Key,
		Leaf:        idp.Certificate,
	}
	for _, cert := range idp.Intermediates {
		keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
	}

	signingContext := dsig.NewDefaultSigningContext(dsig.TLSCertKeyStore(keyPair))
	signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
	if err := signingContext.SetSignatureMethod(idp.SignatureMethod); err != nil {
		return err
	}
	signedRespEl, err := signingContext.SignEnveloped(resp.Element())
	if err != nil {
		return err
	}
	// the enveloped signature is appended as the last child of the signed element
	sigEl := signedRespEl.Child[len(signedRespEl.Child)-1]
	resp.Signature = sigEl.(*etree.Element)
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"context"
"crypto/tls"
"fmt"
"github.com/beevik/etree"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
"github.com/crewjam/saml"
"github.com/crewjam/saml/xmlenc"
dsig "github.com/russellhaering/goxmldsig"
)
// canonicalizerPrefixList is the inclusive-namespace prefix list given to the exclusive C14N 1.0
// canonicalizer when signing; empty means no prefixes from the enclosing document are included.
const canonicalizerPrefixList = ""
// AttributeGenerator produces additional SAML attributes for an account. MakeAssertion appends them
// after the built-in "urn:mace:dir:attribute-def:uid" and "Username" attributes.
type AttributeGenerator func(account security.Account) []saml.Attribute
// MakeAssertion This is similar to the method in saml.IdpAuthnRequest
// but we have our own logic for generating attributes.
//
// It sets req.Assertion to an (unsigned) assertion for the authenticated user:
//   - the subject NameID and the "urn:mace:dir:attribute-def:uid" and "Username" attributes carry the
//     username;
//   - generator's attributes are appended when the principal is a security.Account and generator is set;
//   - the validity window starts saml.MaxClockSkew before req.Now (but not before the request's
//     IssueInstant) and lasts saml.MaxIssueDelay;
//   - the AuthnContextClassRef reflects the auth method recorded in the authentication details.
//
// A SAML internal error is returned if the username cannot be determined.
func MakeAssertion(ctx context.Context, req *saml.IdpAuthnRequest, authentication security.Authentication, generator AttributeGenerator) error {
	username, err := security.GetUsername(authentication)
	if err != nil {
		return NewSamlInternalError("can't get username from authentication", err)
	}

	attributes := []saml.Attribute{
		{
			Name:       "urn:mace:dir:attribute-def:uid",
			NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
			Values:     []saml.AttributeValue{{Type: "xs:string", Value: username}},
		},
		{
			Name:       "Username",
			NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:unspecified",
			Values:     []saml.AttributeValue{{Type: "xs:string", Value: username}},
		},
	}
	if acct, isAccount := authentication.Principal().(security.Account); isAccount && generator != nil {
		attributes = append(attributes, generator(acct)...)
	}

	// allow for some clock skew in the validity period using the issuer's apparent clock
	validFrom := req.Now.Add(-1 * saml.MaxClockSkew)
	validUntil := req.Now.Add(saml.MaxIssueDelay)
	if validFrom.Before(req.Request.IssueInstant) {
		validFrom = req.Request.IssueInstant
		validUntil = validFrom.Add(saml.MaxIssueDelay)
	}

	authCtxClassRef := "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified"
	if details, isMap := authentication.Details().(map[string]interface{}); isMap {
		// an absent auth method yields nil, which matches none of the cases
		switch details[security.DetailsKeyAuthMethod] {
		case security.AuthMethodPassword:
			authCtxClassRef = "urn:oasis:names:tc:SAML:2.0:ac:classes:Password"
		case security.AuthMethodExternalSaml, security.AuthMethodExternalOpenID:
			authCtxClassRef = "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocol"
		}
	}

	req.Assertion = &saml.Assertion{
		ID:           fmt.Sprintf("id-%x", cryptoutils.RandomBytes(20)),
		IssueInstant: saml.TimeNow(),
		Version:      "2.0",
		Issuer: saml.Issuer{
			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
			Value:  req.IDP.Metadata().EntityID,
		},
		Subject: &saml.Subject{
			NameID: &saml.NameID{
				Format: "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified",
				Value:  username,
			},
			SubjectConfirmations: []saml.SubjectConfirmation{
				{
					Method: "urn:oasis:names:tc:SAML:2.0:cm:bearer",
					SubjectConfirmationData: &saml.SubjectConfirmationData{
						InResponseTo: req.Request.ID,
						NotOnOrAfter: req.Now.Add(saml.MaxIssueDelay),
						Recipient:    req.ACSEndpoint.Location,
					},
				},
			},
		},
		Conditions: &saml.Conditions{
			NotBefore:    validFrom,
			NotOnOrAfter: validUntil,
			AudienceRestrictions: []saml.AudienceRestriction{
				{
					Audience: saml.Audience{Value: req.ServiceProviderMetadata.EntityID},
				},
			},
		},
		AuthnStatements: []saml.AuthnStatement{
			{
				AuthnInstant: security.DetermineAuthenticationTime(ctx, authentication),
				AuthnContext: saml.AuthnContext{
					AuthnContextClassRef: &saml.AuthnContextClassRef{
						Value: authCtxClassRef,
					},
				},
			},
		},
		AttributeStatements: []saml.AttributeStatement{
			{
				Attributes: attributes,
			},
		},
	}
	return nil
}
// MakeAssertionEl signs req.Assertion and stores the resulting XML element in req.AssertionEl.
// Unless skipEncryption is true, the signed assertion is additionally encrypted with the SP's
// "encryption" certificate and wrapped in a <saml:EncryptedAssertion> element.
//
// This is similar to the implementation in saml.IdpAuthnRequest; it is re-implemented here
// because we need to optionally skip encryption.
//
// Errors:
//   - SAML responder error if the signature method is unsupported, signing fails or encryption fails;
//   - SAML requester error if the SP metadata does not provide a usable encryption certificate;
//   - SAML internal error if the signed assertion cannot be serialized.
func MakeAssertionEl(req *saml.IdpAuthnRequest, skipEncryption bool) error {
	keyPair := tls.Certificate{
		Certificate: [][]byte{req.IDP.Certificate.Raw},
		PrivateKey:  req.IDP.Key,
		Leaf:        req.IDP.Certificate,
	}
	for _, cert := range req.IDP.Intermediates {
		keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
	}
	keyStore := dsig.TLSCertKeyStore(keyPair)

	signatureMethod := req.IDP.SignatureMethod
	if signatureMethod == "" {
		signatureMethod = dsig.RSASHA1SignatureMethod
	}
	signingContext := dsig.NewDefaultSigningContext(keyStore)
	// Use exclusive canonicalization so the assertion is canonicalized (and signed) independently of
	// its XML context: the assertion payload is signed independently of the response envelope, so the
	// envelope's namespaces must not be pulled into the assertion's canonical form.
	// NOTE(review): canonicalizerPrefixList is defined elsewhere; it is intended to include none of the
	// envelope's namespace prefixes — confirm its contents.
	signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
	if err := signingContext.SetSignatureMethod(signatureMethod); err != nil {
		return NewSamlResponderError("unsupported signature method for signing assertion", err)
	}

	signedAssertionEl, err := signingContext.SignEnveloped(req.Assertion.Element())
	if err != nil {
		return NewSamlResponderError("error signing assertion", err)
	}
	// SignEnveloped appends the signature as the last child element. Use ChildElements() rather than
	// indexing Child directly so a trailing non-element token can never cause a panic.
	children := signedAssertionEl.ChildElements()
	if len(children) == 0 {
		return NewSamlResponderError("error signing assertion: signature element not found")
	}
	req.Assertion.Signature = children[len(children)-1]
	signedAssertionEl = req.Assertion.Element()

	if skipEncryption {
		req.AssertionEl = signedAssertionEl
		return nil
	}

	certBuf, err := getServiceProviderCert(req, "encryption")
	if err != nil {
		return NewSamlRequesterError("requester doesn't provide encryption key in metadata", err)
	}

	doc := etree.NewDocument()
	doc.SetRoot(signedAssertionEl)
	signedAssertionBuf, err := doc.WriteToBytes()
	if err != nil {
		return NewSamlInternalError("error serializing signed assertion", err)
	}

	encryptor := xmlenc.OAEP()
	encryptor.BlockCipher = xmlenc.AES128CBC
	encryptor.DigestMethod = &xmlenc.SHA1
	encryptedDataEl, err := encryptor.Encrypt(certBuf, signedAssertionBuf, nil)
	if err != nil {
		return NewSamlResponderError("error encrypting assertion", err)
	}
	encryptedDataEl.CreateAttr("Type", "http://www.w3.org/2001/04/xmlenc#Element")

	encryptedAssertionEl := etree.NewElement("saml:EncryptedAssertion")
	encryptedAssertionEl.AddChild(encryptedDataEl)
	req.AssertionEl = encryptedAssertionEl
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"bytes"
"crypto/x509"
"encoding/xml"
"fmt"
samlutils "github.com/cisco-open/go-lanai/pkg/security/saml/utils"
"github.com/crewjam/saml"
xrv "github.com/mattermost/xml-roundtrip-validator"
"net/http"
"strconv"
)
// UnmarshalRequest decodes req.RequestBuffer into req.Request.
// The raw XML is first checked with the XML round-trip validator. A buffer that fails that check is
// reported as a SAML requester error. A failure of the actual unmarshal is reported as a SAML internal error.
func UnmarshalRequest(req *saml.IdpAuthnRequest) error {
	raw := req.RequestBuffer
	if validationErr := xrv.Validate(bytes.NewReader(raw)); validationErr != nil {
		return NewSamlRequesterError("authentication request is not valid xml", validationErr)
	}
	if decodeErr := xml.Unmarshal(raw, &req.Request); decodeErr != nil {
		return NewSamlInternalError("error unmarshal authentication request xml", decodeErr)
	}
	return nil
}
// ValidateAuthnRequest validates an SP-initiated authentication request.
// It is similar to the method in saml.IdpAuthnRequest, but that implementation supports neither
// signature checks nor destination checks, so it is re-implemented here with support for both.
//
// Checks, in order:
//   - the request signature, unless spDetails.SkipAuthRequestSignatureVerification is set;
//   - Destination, when present, must equal the IdP SSO URL;
//   - IssueInstant + saml.MaxIssueDelay must not be before req.Now;
//   - Version must be "2.0".
//
// spMetadata is currently unused.
func ValidateAuthnRequest(req *saml.IdpAuthnRequest, spDetails SamlSpDetails, spMetadata *saml.EntityDescriptor) error {
	if !spDetails.SkipAuthRequestSignatureVerification {
		if sigErr := verifySignature(req); sigErr != nil {
			return NewSamlRequesterError("request signature cannot be verified", sigErr)
		}
	}

	if destination := req.Request.Destination; destination != "" {
		if expected := req.IDP.SSOURL.String(); destination != expected {
			return NewSamlResponderError(fmt.Sprintf("expected destination to be %q, not %q", expected, destination))
		}
	}

	expiresAt := req.Request.IssueInstant.Add(saml.MaxIssueDelay)
	if expiresAt.Before(req.Now) {
		return NewSamlResponderError(fmt.Sprintf("request expired at %s", expiresAt))
	}

	if req.Request.Version != "2.0" {
		return NewSamlRequestVersionMismatch("expected saml version 2.0")
	}
	return nil
}
// verifySignature verifies the signature of the incoming AuthnRequest against the SP's "signing"
// certificate from its metadata. GET requests are treated as HTTP-Redirect binding and everything
// else as HTTP-POST binding.
func verifySignature(req *saml.IdpAuthnRequest) error {
	var binding string
	switch req.HTTPRequest.Method {
	case http.MethodGet:
		binding = saml.HTTPRedirectBinding
	default:
		binding = saml.HTTPPostBinding
	}

	signingCert, certErr := getServiceProviderCert(req, "signing")
	if certErr != nil {
		return NewSamlRequesterError("request signature cannot be verified, because metadata does not include certificate", certErr)
	}

	configure := func(sc *samlutils.SignatureContext) {
		sc.Binding = binding
		sc.XMLData = req.RequestBuffer
		sc.Certs = []*x509.Certificate{signingCert}
		sc.Request = req.HTTPRequest
	}
	return samlutils.VerifySignature(configure)
}
// DetermineACSEndpoint picks the SP assertion consumer service (ACS) endpoint the response will be
// delivered to, and stores a copy of it in req.ACSEndpoint.
//
// Candidates are req.SPSSODescriptor.AssertionConsumerServices, searched in this order:
//  1. when the request has an AssertionConsumerServiceIndex, the first endpoint whose Index matches;
//  2. when the request has an AssertionConsumerServiceURL, the first endpoint whose Location matches;
//  3. only when the request specifies neither (some service providers, like the Microsoft Azure AD
//     service provider, do this): the first endpoint marked as default with an HTTP-POST or
//     HTTP-Redirect binding, otherwise the first endpoint with either of those bindings.
//
// Returns a SAML requester error when nothing is selected.
func DetermineACSEndpoint(req *saml.IdpAuthnRequest) error {
	candidates := req.SPSSODescriptor.AssertionConsumerServices

	// selectFirst stores a copy of the first candidate accepted by pred and reports whether one was found.
	selectFirst := func(pred func(ep saml.IndexedEndpoint) bool) bool {
		for i := range candidates {
			if pred(candidates[i]) {
				selected := candidates[i]
				req.ACSEndpoint = &selected
				return true
			}
		}
		return false
	}
	// isHTTPBinding accepts the bindings this IdP can respond with.
	isHTTPBinding := func(ep saml.IndexedEndpoint) bool {
		return ep.Binding == saml.HTTPPostBinding || ep.Binding == saml.HTTPRedirectBinding
	}

	index := req.Request.AssertionConsumerServiceIndex
	location := req.Request.AssertionConsumerServiceURL

	if index != "" && selectFirst(func(ep saml.IndexedEndpoint) bool { return strconv.Itoa(ep.Index) == index }) {
		return nil
	}
	if location != "" && selectFirst(func(ep saml.IndexedEndpoint) bool { return ep.Location == location }) {
		return nil
	}
	if index == "" && location == "" {
		isDefaultHTTP := func(ep saml.IndexedEndpoint) bool {
			return ep.IsDefault != nil && *ep.IsDefault && isHTTPBinding(ep)
		}
		if selectFirst(isDefaultHTTP) || selectFirst(isHTTPBinding) {
			return nil
		}
	}
	return NewSamlRequesterError("assertion consumer service not found")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/cisco-open/go-lanai/pkg/web/mapping"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
"net/http"
)
// SamlAuthorizeEndpointConfigurer configures the SAML IdP SSO (authorize) endpoint and the IdP
// metadata endpoint on a WebSecurity. It embeds samlConfigurer (SAML properties and client store).
// The account store and attribute generator are handed to the authorize middleware.
type SamlAuthorizeEndpointConfigurer struct {
	samlConfigurer
	accountStore       security.AccountStore
	attributeGenerator AttributeGenerator
}

// newSamlAuthorizeEndpointConfigurer creates a SamlAuthorizeEndpointConfigurer from its dependencies.
func newSamlAuthorizeEndpointConfigurer(properties samlctx.SamlProperties,
	samlClientStore samlctx.SamlClientStore,
	accountStore security.AccountStore,
	attributeGenerator AttributeGenerator) *SamlAuthorizeEndpointConfigurer {
	configurer := &SamlAuthorizeEndpointConfigurer{
		accountStore:       accountStore,
		attributeGenerator: attributeGenerator,
	}
	// populate the embedded samlConfigurer via its promoted fields
	configurer.properties = properties
	configurer.samlClientStore = samlClientStore
	return configurer
}
// Apply registers the SAML IdP endpoints for the given feature on ws:
//   - a metadata-refresh middleware and the SSO authorize middleware on the SSO path (GET and POST);
//   - no-op GET/POST handlers on the SSO path, since the middleware does the actual work;
//   - the IdP metadata endpoint at f.metadataPath;
//   - the SAML error handler as an additional error handler.
//
// feature is expected to be a *Feature; any other type panics on the type assertion.
func (c *SamlAuthorizeEndpointConfigurer) Apply(feature security.Feature, ws security.WebSecurity) (err error) {
	f := feature.(*Feature)
	authorizeMw := NewSamlAuthorizeEndpointMiddleware(c.metadataMiddleware(f), c.accountStore, c.attributeGenerator)

	refreshMapping := middleware.NewBuilder("Saml Service Provider Refresh").
		ApplyTo(matcher.RouteWithPattern(f.ssoLocation.Path, http.MethodGet, http.MethodPost)).
		Order(security.MWOrderSAMLMetadataRefresh).
		Use(authorizeMw.RefreshMetadataHandler(f.ssoCondition))
	ssoMapping := middleware.NewBuilder("Saml SSO").
		ApplyTo(matcher.RouteWithPattern(f.ssoLocation.Path, http.MethodGet, http.MethodPost)).
		Order(security.MWOrderSamlAuthEndpoints).
		Use(authorizeMw.AuthorizeHandlerFunc(f.ssoCondition))
	ws.Add(refreshMapping).Add(ssoMapping)

	ws.Add(mapping.Get(f.ssoLocation.Path).HandlerFunc(security.NoopHandlerFunc()))
	ws.Add(mapping.Post(f.ssoLocation.Path).HandlerFunc(security.NoopHandlerFunc()))

	// unlike SSO, metadata is served by an actual endpoint
	metadataEndpoint := mapping.Get(f.metadataPath).
		HandlerFunc(authorizeMw.MetadataHandlerFunc()).
		Name("saml metadata")
	ws.Add(metadataEndpoint)

	errorhandling.Configure(ws).
		AdditionalErrorHandler(NewSamlErrorHandler())
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"crypto/tls"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
"github.com/crewjam/saml"
dsig "github.com/russellhaering/goxmldsig"
)
// MakeErrorResponse builds a signed SAML <Response> that carries only a status (code and message)
// for req, and stores it in req.ResponseEl.
//
// The response is addressed to req.ACSEndpoint.Location and answers req.Request.ID. It is signed
// with the IdP key and certificate chain using req.IDP.SignatureMethod (RSA-SHA1 when unset) and
// exclusive C14N with canonicalizerPrefixList. Signing errors are returned as-is (not wrapped).
func MakeErrorResponse(req *saml.IdpAuthnRequest, code string, message string) error {
	response := &saml.Response{
		Destination:  req.ACSEndpoint.Location,
		ID:           fmt.Sprintf("id-%x", cryptoutils.RandomBytes(20)),
		InResponseTo: req.Request.ID,
		IssueInstant: req.Now,
		Version:      "2.0",
		Issuer: &saml.Issuer{
			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
			Value:  req.IDP.MetadataURL.String(),
		},
		Status: saml.Status{
			StatusCode:    saml.StatusCode{Value: code},
			StatusMessage: &saml.StatusMessage{Value: message},
		},
	}
	unsignedEl := response.Element()

	chain := [][]byte{req.IDP.Certificate.Raw}
	for _, intermediate := range req.IDP.Intermediates {
		chain = append(chain, intermediate.Raw)
	}
	keyStore := dsig.TLSCertKeyStore(tls.Certificate{
		Certificate: chain,
		PrivateKey:  req.IDP.Key,
		Leaf:        req.IDP.Certificate,
	})

	method := req.IDP.SignatureMethod
	if method == "" {
		method = dsig.RSASHA1SignatureMethod
	}
	signer := dsig.NewDefaultSigningContext(keyStore)
	signer.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
	if err := signer.SetSignatureMethod(method); err != nil {
		return err
	}

	signedEl, err := signer.SignEnveloped(unsignedEl)
	if err != nil {
		return err
	}
	// the enveloped signature is the last child element of the signed copy
	elements := signedEl.ChildElements()
	response.Signature = elements[len(elements)-1]

	req.ResponseEl = response.Element()
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/cisco-open/go-lanai/pkg/tenancy"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/crewjam/saml"
"github.com/gin-gonic/gin"
)
// SamlAuthorizeEndpointMiddleware serves the SAML IdP SSO endpoint. It embeds MetadataMiddleware,
// which provides the IdP, SP metadata and SAML client lookups it relies on.
type SamlAuthorizeEndpointMiddleware struct {
	*MetadataMiddleware
	accountStore       security.AccountStore
	attributeGenerator AttributeGenerator
}

// NewSamlAuthorizeEndpointMiddleware creates a SamlAuthorizeEndpointMiddleware on top of metaMw.
// accountStore is used for tenant-restriction checks and attributeGenerator for assertion attributes.
func NewSamlAuthorizeEndpointMiddleware(metaMw *MetadataMiddleware,
	accountStore security.AccountStore,
	attributeGenerator AttributeGenerator) *SamlAuthorizeEndpointMiddleware {
	return &SamlAuthorizeEndpointMiddleware{
		MetadataMiddleware: metaMw,
		accountStore:       accountStore,
		attributeGenerator: attributeGenerator,
	}
}
// AuthorizeHandlerFunc returns the gin handler for the SAML SSO endpoint.
//
// The handler does nothing unless condition matches the request. Otherwise it:
//  1. builds the authentication request. It is IdP-initiated when the form has idp_init=true (the SP
//     is taken from the entity_id form value); otherwise the SP's AuthnRequest is decoded and unmarshalled;
//  2. requires an authenticated security context;
//  3. loads the SP metadata for the request issuer and selects the ACS endpoint;
//  4. validates the AuthnRequest (SP-initiated only) and the SAML client's tenant restrictions;
//  5. builds and signs the assertion, encrypting it unless the SP opts out, writes the response and
//     aborts the remaining gin handlers.
//
// Failures are reported through handleError. Once the ACS endpoint is known, the request is passed
// along with the error (presumably so the SAML error handler can address an error response to the SP).
func (mw *SamlAuthorizeEndpointMiddleware) AuthorizeHandlerFunc(condition web.RequestMatcher) gin.HandlerFunc {
	return func(ctx *gin.Context) {
		if matches, err := condition.MatchesWithContext(ctx, ctx.Request); !matches || err != nil {
			return
		}

		var req *saml.IdpAuthnRequest
		var err error
		idpInitiatedMatcher := matcher.RequestWithForm("idp_init", "true")
		isIdpInit, _ := idpInitiatedMatcher.Matches(ctx.Request)
		if isIdpInit {
			entityId := ctx.Request.Form.Get("entity_id")
			if entityId == "" {
				mw.handleError(ctx, nil, NewSamlInternalError("error start idp initiated sso, no SP entity id provided"))
				return
			}
			req = &saml.IdpAuthnRequest{
				Request: saml.AuthnRequest{
					Issuer: &saml.Issuer{
						Value: entityId,
					},
					IssueInstant: saml.TimeNow(),
				},
				IDP: mw.idp,
				Now: saml.TimeNow(),
			}
		} else {
			req, err = saml.NewIdpAuthnRequest(mw.idp, ctx.Request)
			if err != nil {
				mw.handleError(ctx, nil, NewSamlInternalError("error decoding authentication request", err))
				return
			}
			if err = UnmarshalRequest(req); err != nil {
				mw.handleError(ctx, nil, err)
				return
			}
		}

		// sanity checks
		authentication := security.Get(ctx)
		if authentication == nil {
			mw.handleError(ctx, nil, NewSamlInternalError("no authentication found"))
			return
		}
		if authentication.State() < security.StateAuthenticated {
			mw.handleError(ctx, nil, NewSamlInternalError("session is not authenticated"))
			return
		}

		// An SP-initiated request is untrusted XML; <Issuer> is optional in the schema and may be absent.
		// Without it we cannot identify the SP (and dereferencing it would panic).
		if req.Request.Issuer == nil {
			mw.handleError(ctx, nil, NewSamlRequesterError("authentication request does not have an issuer"))
			return
		}
		serviceProviderID := req.Request.Issuer.Value

		// find the service provider metadata
		spDetails, spMetadata, err := mw.spMetadataManager.GetServiceProvider(serviceProviderID)
		if err != nil {
			mw.handleError(ctx, nil, NewSamlInternalError("cannot find service provider metadata", err))
			return
		}
		if len(spMetadata.SPSSODescriptors) != 1 {
			mw.handleError(ctx, nil, NewSamlInternalError("expected exactly one SP SSO descriptor in SP metadata"))
			return
		}
		req.ServiceProviderMetadata = spMetadata
		req.SPSSODescriptor = &spMetadata.SPSSODescriptors[0]

		// Check that the ACS URL matches an ACS endpoint in the SP metadata.
		// After this point we have an endpoint to send responses to, whether success or failure.
		if err = DetermineACSEndpoint(req); err != nil {
			mw.handleError(ctx, nil, err)
			return
		}

		if !isIdpInit {
			if err = ValidateAuthnRequest(req, spDetails, spMetadata); err != nil {
				mw.handleError(ctx, req, err)
				return
			}
		}

		// check tenancy
		client, err := mw.samlClientStore.GetSamlClientByEntityId(ctx.Request.Context(), serviceProviderID)
		if err != nil {
			// We shouldn't get an error here because we already have the SP's metadata.
			// If an error does occur, it indicates a programming error.
			mw.handleError(ctx, nil, NewSamlInternalError("saml client not found", err))
			return
		}
		if err = mw.validateTenantRestriction(ctx, client, authentication); err != nil {
			mw.handleError(ctx, req, err)
			return
		}

		if err = MakeAssertion(ctx, req, authentication, mw.attributeGenerator); err != nil {
			mw.handleError(ctx, req, err)
			return
		}
		if err = MakeAssertionEl(req, spDetails.SkipAssertionEncryption); err != nil {
			mw.handleError(ctx, req, err)
			return
		}

		if err = req.WriteResponse(ctx.Writer); err != nil {
			mw.handleError(ctx, nil, NewSamlInternalError("error writing saml response", err))
			return
		}
		// abort the rest of the handlers because we have already written the response successfully
		ctx.Abort()
	}
}
// handleError records err on the gin context and aborts the handler chain.
// Errors that are not already of the SAML error type are wrapped as a SAML internal error.
// When authRequest is non-nil it is stored under CtxKeySamlAuthnRequest for downstream error handling.
func (mw *SamlAuthorizeEndpointMiddleware) handleError(c *gin.Context, authRequest *saml.IdpAuthnRequest, err error) {
	reported := err
	if !errors.Is(reported, security.ErrorTypeSaml) {
		reported = NewSamlInternalError("saml sso internal error", err)
	}
	if authRequest != nil {
		c.Set(CtxKeySamlAuthnRequest, authRequest)
	}
	// the returned *gin.Error is not needed; the error is kept on the context
	_ = c.Error(reported)
	c.Abort()
}
// validateTenantRestriction checks that the authenticated user may use the SAML client, based on
// the client's tenant restrictions.
//
// There is no restriction when the client has none, or when the user's designated tenants include
// the wildcard tenant. Otherwise, with TenantRestrictionTypeAny at least one restricted tenant must be
// reachable from (a descendant of) the user's tenants. With any other type (treated as "all"), every
// restricted tenant must be reachable. All failures are reported as SAML internal errors.
func (mw *SamlAuthorizeEndpointMiddleware) validateTenantRestriction(ctx context.Context, client samlctx.SamlClient, auth security.Authentication) error {
	restrictedTo := client.GetTenantRestrictions()
	if len(restrictedTo) == 0 {
		return nil
	}

	username, err := security.GetUsername(auth)
	if err != nil {
		return NewSamlInternalError("cannot validate tenancy restriction due to unknown username", err)
	}
	account, err := mw.accountStore.LoadAccountByUsername(ctx, username)
	if err != nil {
		return NewSamlInternalError("cannot validate tenancy restriction due to error fetching account", err)
	}
	tenancyAware, isTenancyAware := account.(security.AccountTenancy)
	if !isTenancyAware {
		return NewSamlInternalError(fmt.Sprintf("cannot validate tenancy restriction due to unsupported account implementation: %T", account))
	}
	accessible := utils.NewStringSet(tenancyAware.DesignatedTenantIds()...)
	if accessible.Has(security.SpecialTenantIdWildcard) {
		return nil
	}

	// anything other than TenantRestrictionTypeAny is treated as TenantRestrictionTypeAll
	requireAll := client.GetTenantRestrictionType() != TenantRestrictionTypeAny
	for tenantId := range restrictedTo {
		reachable := tenancy.AnyHasDescendant(ctx, accessible, tenantId)
		switch {
		case reachable && !requireAll:
			return nil
		case !reachable && requireAll:
			return NewSamlInternalError("client is restricted to tenants which the authenticated user does not have access to")
		}
	}
	if requireAll {
		return nil
	}
	return NewSamlInternalError("client is restricted to tenants which the authenticated user does not have access to")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlidp
import (
"crypto/x509"
"encoding/base64"
"github.com/crewjam/saml"
"regexp"
)
// certWhitespaceRegex matches whitespace (including line breaks) commonly embedded in base64
// X509Certificate values in SAML metadata. It is compiled once rather than on every lookup.
var certWhitespaceRegex = regexp.MustCompile(`\s+`)

// getServiceProviderCert returns the SP certificate for the given key usage (e.g. "signing" or
// "encryption") from req.SPSSODescriptor.
//
// It uses the first KeyDescriptor whose Use equals usage and that lists at least one certificate.
// If that yields no certificate data, it falls back to the first KeyDescriptor without a Use attribute
// that has a non-empty certificate. Only the first certificate of a descriptor is considered.
//
// Returns a SAML internal error if no certificate is found, or if it cannot be base64-decoded or
// parsed as X.509.
func getServiceProviderCert(req *saml.IdpAuthnRequest, usage string) (*x509.Certificate, error) {
	certStr := ""
	for _, keyDescriptor := range req.SPSSODescriptor.KeyDescriptors {
		if keyDescriptor.Use == usage && len(keyDescriptor.KeyInfo.X509Data.X509Certificates) > 0 {
			certStr = keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data
			break
		}
	}

	// If there is no cert explicitly labeled for this usage, use the first non-empty cert
	// whose key descriptor has no usage label.
	if certStr == "" {
		for _, keyDescriptor := range req.SPSSODescriptor.KeyDescriptors {
			if keyDescriptor.Use == "" &&
				len(keyDescriptor.KeyInfo.X509Data.X509Certificates) > 0 &&
				keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data != "" {
				certStr = keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data
				break
			}
		}
	}
	if certStr == "" {
		return nil, NewSamlInternalError("certificate not found")
	}

	// strip whitespace, then decode the DER bytes
	certStr = certWhitespaceRegex.ReplaceAllString(certStr, "")
	certBytes, err := base64.StdEncoding.DecodeString(certStr)
	if err != nil {
		// NewSamlInternalError is not printf-style: pass the cause instead of a "%v" verb
		return nil, NewSamlInternalError("cannot decode certificate base64", err)
	}
	cert, err := x509.ParseCertificate(certBytes)
	if err != nil {
		return nil, NewSamlInternalError("cannot parse certificate", err)
	}
	return cert, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlctx
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"go.uber.org/fx"
)
// Module is the bootstrap module for shared SAML support. It provides SamlProperties via BindSamlProperties.
var Module = &bootstrap.Module{
	Precedence: security.MinSecurityPrecedence + 20,
	Name:       "saml",
	Options:    []fx.Option{fx.Provide(BindSamlProperties)},
}

// init registers Module with bootstrap when this package is imported.
func init() {
	bootstrap.Register(Module)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlctx
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/pkg/errors"
)
// SamlPropertiesPrefix is the configuration prefix that SamlProperties is bound from (see BindSamlProperties).
const SamlPropertiesPrefix = "security.auth.saml"
// SamlProperties holds the shared SAML configuration bound from SamlPropertiesPrefix.
type SamlProperties struct {
// CertificateFile is the location of a certificate file.
// NOTE(review): presumably the local party's SAML signing/encryption certificate — confirm with consumers.
CertificateFile string `json:"certificate-file"`
// KeyFile is the location of a private key file, presumably the key matching CertificateFile — TODO confirm.
KeyFile string `json:"key-file"`
// KeyPassword is the password for KeyFile, presumably only needed when the key is encrypted — TODO confirm.
KeyPassword string `json:"key-password"`
// NameIDFormat is the NameID format. NewSamlProperties defaults it to
// "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified".
NameIDFormat string `json:"name-id-format"`
}
// NewSamlProperties returns SamlProperties populated with defaults.
func NewSamlProperties() *SamlProperties {
	props := &SamlProperties{}
	// Default to the "unspecified" format so that auth requests generated by the crewjam/saml package
	// carry no NameIDFormat by default (see saml.nameIDFormat() in github.com/crewjam/saml).
	props.NameIDFormat = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"
	return props
}

// BindSamlProperties binds SamlProperties from the application configuration under
// SamlPropertiesPrefix, on top of the defaults from NewSamlProperties. It panics if binding fails.
func BindSamlProperties(ctx *bootstrap.ApplicationContext) SamlProperties {
	props := NewSamlProperties()
	bindErr := ctx.Config().Bind(props, SamlPropertiesPrefix)
	if bindErr != nil {
		panic(errors.Wrap(bindErr, "failed to bind SamlProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/crewjam/saml"
)
// AssertionCandidate is an authentication candidate backed by a SAML assertion received by the SP.
// DetailsMap carries extra authentication details alongside the assertion.
type AssertionCandidate struct {
	Assertion  *saml.Assertion
	DetailsMap map[string]interface{}
}

// Principal returns the assertion's Subject NameID value, or nil when the assertion has no
// Subject or no NameID.
func (a *AssertionCandidate) Principal() interface{} {
	if subject := a.Assertion.Subject; subject != nil && subject.NameID != nil {
		return subject.NameID.Value
	}
	return nil
}

// Credentials returns the assertion itself.
func (a *AssertionCandidate) Credentials() interface{} {
	return a.Assertion
}

// Details returns DetailsMap.
func (a *AssertionCandidate) Details() interface{} {
	return a.DetailsMap
}
// SamlAssertionAuthentication is a security.Authentication produced from a SAML assertion that also
// exposes the assertion it was built from.
type SamlAssertionAuthentication interface {
	security.Authentication
	Assertion() *saml.Assertion
}

// samlAssertionAuthentication is the SamlAssertionAuthentication created by Authenticator.
type samlAssertionAuthentication struct {
	Account       security.Account
	Perms         map[string]interface{}
	DetailsMap    map[string]interface{}
	SamlAssertion *saml.Assertion
}

// compile-time check that samlAssertionAuthentication satisfies SamlAssertionAuthentication
var _ SamlAssertionAuthentication = (*samlAssertionAuthentication)(nil)

// Principal returns the local account the assertion was mapped to.
func (sa *samlAssertionAuthentication) Principal() interface{} { return sa.Account }

// Permissions returns the account's permissions.
func (sa *samlAssertionAuthentication) Permissions() security.Permissions { return sa.Perms }

// State always reports security.StateAuthenticated.
func (sa *samlAssertionAuthentication) State() security.AuthenticationState {
	return security.StateAuthenticated
}

// Details returns DetailsMap.
func (sa *samlAssertionAuthentication) Details() interface{} { return sa.DetailsMap }

// Assertion returns the SAML assertion this authentication was created from.
func (sa *samlAssertionAuthentication) Assertion() *saml.Assertion { return sa.SamlAssertion }
// Authenticator authenticates AssertionCandidate candidates by mapping the assertion's Subject NameID
// to a local account through the federated account store.
type Authenticator struct {
	accountStore security.FederatedAccountStore
	idpManager   samlctx.SamlIdentityProviderManager
}

// Authenticate authenticates candidate if it is an *AssertionCandidate. Any other candidate type
// yields (nil, nil), so other authenticators can handle it.
//
// Steps:
//  1. look up the SAML IdP by the assertion's issuer;
//  2. load (or auto-create, per the IdP's settings) the account identified by the assertion's
//     NameID;
//  3. reject disabled accounts.
//
// On success, the candidate's DetailsMap (created if nil) receives the auth time (the assertion's
// IssueInstant) and the auth method (external SAML).
//
// Errors:
//   - internal authentication error when the IdP is unknown or not a SAML IdP;
//   - internal authentication error when the assertion has no Subject NameID;
//   - internal authentication error when the account cannot be loaded;
//   - account status error when the account is disabled.
func (a *Authenticator) Authenticate(ctx context.Context, candidate security.Candidate) (security.Authentication, error) {
	assertionCandidate, ok := candidate.(*AssertionCandidate)
	if !ok {
		return nil, nil
	}

	idp, err := a.idpManager.GetIdentityProviderByEntityId(ctx, assertionCandidate.Assertion.Issuer.Value)
	if err != nil {
		return nil, security.NewInternalAuthenticationError("Couldn't find idp matching the assertion")
	}
	samlIdp, ok := idp.(samlctx.SamlIdentityProvider)
	if !ok {
		return nil, security.NewInternalAuthenticationError("Couldn't find idp metadata matching the assertion")
	}

	// Principal() returns nil when the assertion lacks a Subject/NameID. An unchecked
	// type assertion would panic on such input.
	externalId, ok := assertionCandidate.Principal().(string)
	if !ok {
		return nil, security.NewInternalAuthenticationError("Couldn't find subject NameID in the assertion")
	}

	user, err := a.accountStore.LoadAccountByExternalId(ctx, samlIdp.ExternalIdName(), externalId, samlIdp.ExternalIdpName(), samlIdp.GetAutoCreateUserDetails(), assertionCandidate.Assertion)
	if err != nil {
		return nil, security.NewInternalAuthenticationError(err)
	}
	if user.Disabled() {
		return nil, security.NewAccountStatusError("Account Disabled")
	}

	permissions := map[string]interface{}{}
	for _, p := range user.Permissions() {
		permissions[p] = true
	}

	// NOTE: this writes into the candidate's own DetailsMap when it is non-nil.
	details := assertionCandidate.DetailsMap
	if details == nil {
		details = make(map[string]interface{})
	}
	details[security.DetailsKeyAuthTime] = assertionCandidate.Assertion.IssueInstant
	details[security.DetailsKeyAuthMethod] = security.AuthMethodExternalSaml

	auth := &samlAssertionAuthentication{
		Account:       user,
		SamlAssertion: assertionCandidate.Assertion,
		Perms:         permissions,
		DetailsMap:    details,
	}
	return auth, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"context"
"crypto/x509"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
samlutils "github.com/cisco-open/go-lanai/pkg/security/saml/utils"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
"github.com/crewjam/saml"
"net"
"net/http"
"reflect"
"sync"
)
// CacheableIdpClientManager keeps one saml.ServiceProvider client per identity provider, keyed by the
// IdP's entity ID. Each client is a copy of template configured with that IdP's resolved metadata.
// All maps are guarded by cacheMutex.
type CacheableIdpClientManager struct {
	template saml.ServiceProvider
	// httpClient is used for fetching IdP metadata
	httpClient *http.Client
	// cache holds the resolved client per IdP entity ID
	cache map[string]*saml.ServiceProvider
	// processed holds the IdP definition each cached client was built from, keyed by entity ID
	processed  map[string]samlctx.SamlIdentityProvider
	cacheMutex sync.RWMutex
}

// NewCacheableIdpClientManager creates an empty manager whose clients are derived from template.
// Metadata is fetched with http.DefaultClient.
func NewCacheableIdpClientManager(template saml.ServiceProvider) *CacheableIdpClientManager {
	manager := &CacheableIdpClientManager{
		template:   template,
		httpClient: http.DefaultClient,
	}
	manager.cache = make(map[string]*saml.ServiceProvider)
	manager.processed = make(map[string]samlctx.SamlIdentityProvider)
	return manager
}
// RefreshCache synchronizes the cache with identityProviders. Entries for IdPs no longer listed are
// dropped. IdPs that are new or whose definition changed have their metadata resolved and cached.
//
// The diff is first computed under the read lock. Only when something changed is the write lock taken.
// The diff is then recomputed, because another goroutine may have applied the update between
// RUnlock and Lock. IdPs whose metadata cannot be resolved are removed/left out of the cache.
func (m *CacheableIdpClientManager) RefreshCache(ctx context.Context, identityProviders []samlctx.SamlIdentityProvider) {
	m.cacheMutex.RLock()
	stale, changed := m.compareWithCache(identityProviders)
	m.cacheMutex.RUnlock()
	if len(changed) == 0 && len(stale) == 0 {
		return
	}

	m.cacheMutex.Lock()
	defer m.cacheMutex.Unlock()
	stale, changed = m.compareWithCache(identityProviders)
	if len(changed) == 0 && len(stale) == 0 {
		return
	}

	clients := m.resolveMetadata(ctx, changed)
	for entityId, drop := range stale {
		if !drop {
			continue
		}
		delete(m.cache, entityId)
		delete(m.processed, entityId)
	}
	for _, idp := range changed {
		entityId := idp.EntityId()
		client, found := clients[entityId]
		if !found {
			continue
		}
		m.cache[entityId] = client
		m.processed[entityId] = idp
	}
}
// compareWithCache diffs identityProviders against the cache. The caller must hold cacheMutex.
//
// refresh lists the IdPs that are not cached yet, or whose definition is not reflect.DeepEqual to the
// one the cached client was built from. remove marks every cached entity ID that is not kept unchanged;
// this includes IdPs that also appear in refresh.
func (m *CacheableIdpClientManager) compareWithCache(identityProviders []samlctx.SamlIdentityProvider) (remove map[string]bool, refresh []samlctx.SamlIdentityProvider) {
	unchanged := make(map[string]bool)
	remove = make(map[string]bool)

	for _, idp := range identityProviders {
		entityId := idp.EntityId()
		_, cached := m.cache[entityId]
		if cached && reflect.DeepEqual(m.processed[entityId], idp) {
			unchanged[entityId] = true
			continue
		}
		refresh = append(refresh, idp)
	}

	for entityId := range m.cache {
		if !unchanged[entityId] {
			remove[entityId] = true
		}
	}
	return remove, refresh
}
// resolveMetadata fetches and validates the metadata of each IdP in refresh. It returns, keyed by IdP
// entity ID, a ServiceProvider client built as a copy of m.template that points at that IdP.
//
// An IdP is logged and left out of the result when its metadata:
//   - cannot be resolved;
//   - is required to be signed but is not;
//   - fails verification against the IdP's configured trusted keys (when trust checking is enabled).
//
// Trusted keys that fail to load are logged and skipped (best effort). The ACS and SLO URL hosts of
// each client are replaced with the IdP's configured domain, keeping the template's port if it has one.
func (m *CacheableIdpClientManager) resolveMetadata(ctx context.Context, refresh []samlctx.SamlIdentityProvider) (resolved map[string]*saml.ServiceProvider) {
	resolved = make(map[string]*saml.ServiceProvider)
	for _, details := range refresh {
		idpDescriptor, data, err := samlutils.ResolveMetadata(ctx, details.MetadataLocation(), samlutils.WithHttpClient(m.httpClient))
		if err != nil {
			// previously swallowed silently; surface it so misconfigured/unreachable IdPs are diagnosable
			logger.WithContext(ctx).Errorf("idp metadata of [%s] cannot be resolved: %v", details.EntityId(), err)
			continue
		}
		if details.ShouldMetadataRequireSignature() && idpDescriptor.Signature == nil {
			logger.WithContext(ctx).Errorf("idp metadata rejected because it is not signed")
			continue
		}
		if details.ShouldMetadataTrustCheck() {
			var allCerts []*x509.Certificate
			for _, keyLoc := range details.GetMetadataTrustedKeys() {
				certs, loadErr := cryptoutils.LoadCert(keyLoc)
				if loadErr != nil {
					logger.WithContext(ctx).Errorf("trusted key [%v] of idp [%s] cannot be loaded: %v", keyLoc, details.EntityId(), loadErr)
					continue
				}
				allCerts = append(allCerts, certs...)
			}
			if verifyErr := samlutils.VerifySignature(samlutils.MetadataSignature(data, allCerts...)); verifyErr != nil {
				logger.WithContext(ctx).Errorf("idp metadata rejected because it's signature cannot be verified")
				continue
			}
		}

		// make a copy of the template
		client := m.template
		client.IDPMetadata = idpDescriptor
		if _, port, splitErr := net.SplitHostPort(client.AcsURL.Host); splitErr == nil {
			client.AcsURL.Host = net.JoinHostPort(details.Domain(), port)
		} else {
			client.AcsURL.Host = details.Domain()
		}
		if _, port, splitErr := net.SplitHostPort(client.SloURL.Host); splitErr == nil {
			client.SloURL.Host = net.JoinHostPort(details.Domain(), port)
		} else {
			client.SloURL.Host = details.Domain()
		}
		resolved[details.EntityId()] = &client
	}
	return resolved
}
// GetAllClients returns a snapshot of all cached clients, in unspecified order.
// It always returns a non-nil slice.
func (m *CacheableIdpClientManager) GetAllClients() []*saml.ServiceProvider {
	m.cacheMutex.RLock()
	defer m.cacheMutex.RUnlock()

	all := make([]*saml.ServiceProvider, 0, len(m.cache))
	for _, sp := range m.cache {
		all = append(all, sp)
	}
	return all
}
// GetClientByComparator returns the cached client of a processed IdP for which comparator
// returns true. Iteration is over a map, so if several IdPs match, which one is returned is
// unspecified.
//
// m.processed and m.cache are separate maps. resolveMetadata leaves out IdPs whose metadata
// fails to resolve or verify, so a matching entry may have no cached client.
// NOTE(review): confirm against the cache refresh logic. Such matches are skipped, so that
// ok == true always comes with a non-nil client. Callers such as MakeAuthenticationRequest and
// ACSHandlerFunc dereference the client whenever ok is true.
func (m *CacheableIdpClientManager) GetClientByComparator(comparator func(details samlctx.SamlIdentityProvider) bool) (client *saml.ServiceProvider, ok bool) {
	m.cacheMutex.RLock()
	defer m.cacheMutex.RUnlock()
	for entityId, details := range m.processed {
		if !comparator(details) {
			continue
		}
		if cached := m.cache[entityId]; cached != nil {
			return cached, true
		}
	}
	return nil, false
}
// GetClientByDomain looks up the cached client of the IdP whose Domain() equals domain.
func (m *CacheableIdpClientManager) GetClientByDomain(domain string) (client *saml.ServiceProvider, ok bool) {
	sameDomain := func(idp samlctx.SamlIdentityProvider) bool {
		return idp.Domain() == domain
	}
	return m.GetClientByComparator(sameDomain)
}
// GetClientByEntityId looks up the cached client of the IdP whose EntityId() equals entityId.
func (m *CacheableIdpClientManager) GetClientByEntityId(entityId string) (client *saml.ServiceProvider, ok bool) {
	sameEntity := func(idp samlctx.SamlIdentityProvider) bool {
		return idp.EntityId() == entityId
	}
	return m.GetClientByComparator(sameEntity)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"crypto/rsa"
"crypto/x509"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/idp"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/golang-jwt/jwt/v4"
dsig "github.com/russellhaering/goxmldsig"
"net/http"
"net/url"
)
// spOptionsHashable is the comparable subset of SPOptions. samlConfigurer uses it as the map
// key for its shared components. SPOptions with equal URL and paths therefore share one
// service provider template, one request tracker and one IdP client manager.
type spOptionsHashable struct {
// URL is the issuer's root URL. The metadata, ACS and SLO URLs are resolved against it.
URL url.URL
// ACSPath is the Assertion Consumer Service path: the issuer URL path followed by the feature's ACS path.
ACSPath string
// MetadataPath is the SP metadata path: the issuer URL path followed by the feature's metadata path.
MetadataPath string
// SLOPath is the Single Logout path: the issuer URL path followed by the feature's SLO path.
SLOPath string
}
// SPOptions holds what is needed to build the SAML service provider template and its request tracker.
// NOTE: getServiceProviderConfiguration does not set Intermediates, AllowIDPInitiated or
// ForceAuthn. They keep their zero values unless set elsewhere.
type SPOptions struct {
spOptionsHashable
// Key is set on the service provider and also signs the JWT-encoded tracked requests.
Key *rsa.PrivateKey
// Certificate is set as the service provider's certificate.
Certificate *x509.Certificate
// Intermediates is copied to the service provider's intermediate certificates.
Intermediates []*x509.Certificate
// AllowIDPInitiated lets the ACS endpoint accept responses with an empty InResponseTo.
AllowIDPInitiated bool
// SignRequest, when false, leaves the service provider's SignatureMethod empty.
SignRequest bool
// ForceAuthn is set on the service provider only when true.
ForceAuthn bool
// NameIdFormat becomes the service provider's AuthnNameIDFormat.
NameIdFormat string
}
// configurerSharedComponents caches the components built by the samlConfigurer.shared* methods
// for one spOptionsHashable. Each field stays nil until it is first requested.
type configurerSharedComponents struct {
serviceProvider *saml.ServiceProvider
tracker samlsp.RequestTracker
clientManager *CacheableIdpClientManager
}
// samlConfigurer is the base implementation for both the login and the logout configurer.
// Login and logout share many components: the service provider template, the request tracker
// and the IdP client manager. These are built once per SP options and then reused.
type samlConfigurer struct {
// properties provides the SP certificate file, key file, key password and NameID format.
properties samlctx.SamlProperties
// idpManager is the generic identity provider manager handed to the middlewares.
idpManager idp.IdentityProviderManager
// samlIdpManager is idpManager asserted to the SAML-specific interface (see newSamlConfigurer).
samlIdpManager samlctx.SamlIdentityProviderManager
// Shared components, keyed by SP options and created on demand by shared().
// NOTE(review): not guarded by a mutex; presumably only used during single-threaded configuration.
components map[spOptionsHashable]*configurerSharedComponents
}
// newSamlConfigurer creates the configurer state shared by the login and logout configurers.
// idpManager must also implement samlctx.SamlIdentityProviderManager. Otherwise the unchecked
// type assertion below panics.
func newSamlConfigurer(properties samlctx.SamlProperties, idpManager idp.IdentityProviderManager) *samlConfigurer {
	samlIdpManager := idpManager.(samlctx.SamlIdentityProviderManager)
	configurer := &samlConfigurer{
		properties:     properties,
		idpManager:     idpManager,
		samlIdpManager: samlIdpManager,
	}
	return configurer
}
// shared returns the component cache entry for hashable. On first use it lazily creates both
// the cache map and an empty entry. The result is never nil.
func (c *samlConfigurer) shared(hashable spOptionsHashable) *configurerSharedComponents {
	if c.components == nil {
		c.components = make(map[spOptionsHashable]*configurerSharedComponents)
	}
	if existing, found := c.components[hashable]; found {
		return existing
	}
	created := &configurerSharedComponents{}
	c.components[hashable] = created
	return created
}
// getServiceProviderConfiguration builds the SP options for feature f:
//   - the SP certificate and private key are loaded from the files named in c.properties;
//     only the first certificate in the file is used;
//   - the metadata, ACS and SLO paths are formed by appending the feature's paths to the
//     issuer URL's path;
//   - requests are marked to be signed, and the NameID format comes from c.properties.
//
// It panics with a security internal error in any of these cases: the certificate or key
// cannot be loaded, the certificate file contains no certificate, the feature has no issuer,
// or the issuer URL cannot be built.
func (c *samlConfigurer) getServiceProviderConfiguration(f *Feature) (opt SPOptions) {
	cert, err := cryptoutils.LoadCert(c.properties.CertificateFile)
	if err != nil {
		panic(security.NewInternalError("cannot load certificate from file", err))
	}
	switch {
	case len(cert) == 0:
		// Guards cert[0] below, which would otherwise panic with an opaque index-out-of-range.
		panic(security.NewInternalError("cannot load certificate from file",
			fmt.Errorf("no certificate found in %v", c.properties.CertificateFile)))
	case len(cert) > 1:
		logger.Warnf("multiple certificate found, using first one")
	}
	key, err := cryptoutils.LoadPrivateKey(c.properties.KeyFile, c.properties.KeyPassword)
	if err != nil {
		panic(security.NewInternalError("cannot load private key from file", err))
	}
	if f.issuer == nil {
		// Calling BuildUrl on a nil interface would panic with a nil-pointer dereference.
		panic(security.NewInternalError("cannot get issuer's base URL", fmt.Errorf("issuer is not configured")))
	}
	rootURL, err := f.issuer.BuildUrl()
	if err != nil {
		panic(security.NewInternalError("cannot get issuer's base URL", err))
	}
	return SPOptions{
		spOptionsHashable: spOptionsHashable{
			URL:          *rootURL,
			ACSPath:      fmt.Sprintf("%s%s", rootURL.Path, f.acsPath),
			MetadataPath: fmt.Sprintf("%s%s", rootURL.Path, f.metadataPath),
			SLOPath:      fmt.Sprintf("%s%s", rootURL.Path, f.sloPath),
		},
		Key:          key,
		Certificate:  cert[0],
		SignRequest:  true,
		NameIdFormat: c.properties.NameIDFormat,
	}
}
// sharedServiceProvider returns the saml.ServiceProvider template for opts. It is built on the
// first call for a given opts.spOptionsHashable; later calls return a copy of the cached value.
//
// Metadata, ACS and SLO URLs are resolved against opts.URL. ForceAuthn is set (as a pointer)
// only when requested. Requests are signed with RSA-SHA1 unless opts.SignRequest is false, in
// which case SignatureMethod is left empty. Only HTTP-POST is offered as the logout binding.
func (c *samlConfigurer) sharedServiceProvider(opts SPOptions) (ret saml.ServiceProvider) {
	cached := c.shared(opts.spOptionsHashable)
	if cached.serviceProvider != nil {
		return *cached.serviceProvider
	}

	resolve := func(path string) url.URL {
		return *opts.URL.ResolveReference(&url.URL{Path: path})
	}
	var forceAuthn *bool
	if opts.ForceAuthn {
		forceAuthn = &opts.ForceAuthn
	}
	signatureMethod := ""
	if opts.SignRequest {
		signatureMethod = dsig.RSASHA1SignatureMethod
	}

	built := saml.ServiceProvider{
		Key:               opts.Key,
		Certificate:       opts.Certificate,
		Intermediates:     opts.Intermediates,
		MetadataURL:       resolve(opts.MetadataPath),
		AcsURL:            resolve(opts.ACSPath),
		SloURL:            resolve(opts.SLOPath),
		ForceAuthn:        forceAuthn,
		SignatureMethod:   signatureMethod,
		AllowIDPInitiated: opts.AllowIDPInitiated,
		AuthnNameIDFormat: saml.NameIDFormat(opts.NameIdFormat),
		LogoutBindings:    []string{saml.HTTPPostBinding},
	}
	cached.serviceProvider = &built
	return built
}
// sharedRequestTracker returns the cookie-based request tracker for opts. It is built once
// per opts.spOptionsHashable and then reused. Tracked requests are encoded as RS256 JWTs
// signed with opts.Key, with opts.URL as both audience and issuer.
func (c *samlConfigurer) sharedRequestTracker(opts SPOptions) (ret samlsp.RequestTracker) {
	cached := c.shared(opts.spOptionsHashable)
	if cached.tracker != nil {
		return cached.tracker
	}

	// The IdP posts back cross-site, so SameSite=None is preferred. Browsers only accept
	// SameSite=None on Secure cookies, which needs https. Over plain http the browser default
	// applies (lax on modern browsers). Lax only sends the cookie cross-site shortly after it
	// is created, so it gives limited support that is good enough for development setups.
	secure := opts.URL.Scheme == "https"
	sameSite := http.SameSiteDefaultMode
	if secure {
		sameSite = http.SameSiteNoneMode
	}

	self := opts.URL.String()
	ret = CookieRequestTracker{
		NamePrefix: "saml_",
		Codec: samlsp.JWTTrackedRequestCodec{
			SigningMethod: jwt.SigningMethodRS256,
			Audience:      self,
			Issuer:        self,
			MaxAge:        saml.MaxIssueDelay,
			Key:           opts.Key,
		},
		MaxAge:   saml.MaxIssueDelay,
		SameSite: sameSite,
		Secure:   secure,
		Path:     opts.ACSPath,
	}
	cached.tracker = ret
	return ret
}
// sharedClientManager returns the IdP client manager for opts. It is created on first use
// from the shared service provider template and cached per opts.spOptionsHashable.
func (c *samlConfigurer) sharedClientManager(opts SPOptions) (ret *CacheableIdpClientManager) {
	cached := c.shared(opts.spOptionsHashable)
	if cached.clientManager == nil {
		cached.clientManager = NewCacheableIdpClientManager(c.sharedServiceProvider(opts))
	}
	return cached.clientManager
}
// effectiveSuccessHandler wraps the feature's success handler. If ws shares a composite
// success handler under security.WSSharedKeyCompositeAuthSuccessHandler, that handler is
// placed first.
func (c *samlConfigurer) effectiveSuccessHandler(f *Feature, ws security.WebSecurity) security.AuthenticationSuccessHandler {
	globalHandler, found := ws.Shared(security.WSSharedKeyCompositeAuthSuccessHandler).(security.AuthenticationSuccessHandler)
	if !found {
		return security.NewAuthenticationSuccessHandler(f.successHandler)
	}
	return security.NewAuthenticationSuccessHandler(globalHandler, f.successHandler)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"github.com/cisco-open/go-lanai/pkg/security"
)
// FeatureId identifies the SAML SP login feature.
var FeatureId = security.FeatureId("saml_login", security.FeatureOrderSamlLogin)

// LogoutFeatureId identifies the SAML SP logout feature.
var LogoutFeatureId = security.FeatureId("saml_logout", security.FeatureOrderSamlLogout)
// Feature configures SAML SP login or SAML SP logout on a WebSecurity (see New and NewLogout).
type Feature struct {
// id tells the login feature (FeatureId) apart from the logout feature (LogoutFeatureId).
id security.FeatureIdentifier
// metadataPath is the SP metadata endpoint path. It is also appended to the issuer URL path when building SP URLs.
metadataPath string
// acsPath is the Assertion Consumer Service endpoint path. It is also appended to the issuer URL path.
acsPath string
// sloPath is the Single Logout endpoint path. It is also appended to the issuer URL path.
sloPath string
errorPath string // path the login fallback entry point redirects to when no SAML auth request can be made
// successHandler runs after login or SLO. If nil, the configurer installs a default.
successHandler security.AuthenticationSuccessHandler
// issuer provides the root URL under which the SP endpoints are exposed.
issuer security.Issuer
}
// new creates a Feature with the given identifier and the default endpoint paths.
// NOTE: this package-level function shadows the builtin new throughout package sp.
func new(id security.FeatureIdentifier) *Feature {
	f := Feature{id: id}
	f.metadataPath = "/saml/metadata"
	f.acsPath = "/saml/SSO"
	f.sloPath = "/saml/slo"
	f.errorPath = "/error"
	return &f
}
// New returns a SAML SP login Feature (FeatureId) with the default endpoint paths.
func New() *Feature {
	login := new(FeatureId)
	return login
}

// NewLogout returns a SAML SP logout Feature (LogoutFeatureId) with the default endpoint paths.
func NewLogout() *Feature {
	logoutFeature := new(LogoutFeatureId)
	return logoutFeature
}
// Identifier implements security.Feature by returning the feature's identifier.
func (f *Feature) Identifier() security.FeatureIdentifier {
	id := f.id
	return id
}

// Issuer sets the issuer whose URL is the root of the SP endpoints, and returns f for chaining.
func (f *Feature) Issuer(issuer security.Issuer) *Feature {
	f.issuer = issuer
	return f
}

// ErrorPath sets the path used by the login fallback entry point, and returns f for chaining.
func (f *Feature) ErrorPath(path string) *Feature {
	f.errorPath = path
	return f
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/request_cache"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/pkg/web/mapping"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
)
// SamlAuthConfigurer applies the SAML login Feature. It registers the ACS and metadata
// endpoints and the IdP metadata refresh middleware, and installs a SAML login entry point.
type SamlAuthConfigurer struct {
*samlConfigurer
// accountStore is passed to the Authenticator that turns validated assertions into authentications.
accountStore security.FederatedAccountStore
}
// Apply wires SAML login into ws:
//   - it routes the ACS and metadata paths to this configuration, and registers the metadata
//     GET endpoint, the ACS POST endpoint and the IdP metadata refresh middleware;
//   - it permits anonymous access to both paths;
//   - it makes SAML login, after saving the original request, the authentication entry point.
//
// feature must be a *Feature; any other type panics on the type assertion.
func (c *SamlAuthConfigurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	f := feature.(*Feature)
	m := c.makeMiddleware(f, ws)

	// The metadata endpoint is available unconditionally. Build() is called so it does not
	// inherit the ws condition.
	metadataEndpoint := mapping.Get(f.metadataPath).
		HandlerFunc(m.MetadataHandlerFunc()).
		Name("saml metadata").Build()
	acsEndpoint := mapping.Post(f.acsPath).
		HandlerFunc(m.ACSHandlerFunc()).
		Name("saml assertion consumer m")
	metadataRefresh := middleware.NewBuilder("saml idp metadata refresh").
		Order(security.MWOrderSAMLMetadataRefresh).
		Use(m.RefreshMetadataHandler())

	ws.Route(matcher.RouteWithPattern(f.acsPath)).
		Route(matcher.RouteWithPattern(f.metadataPath)).
		Add(metadataEndpoint).
		Add(acsEndpoint).
		Add(metadataRefresh)

	publicEndpoints := matcher.RequestWithPattern(f.acsPath).Or(matcher.RequestWithPattern(f.metadataPath))
	access.Configure(ws).
		Request(publicEndpoints).WithOrder(order.Highest).PermitAll()

	errorhandling.Configure(ws).
		AuthenticationEntryPoint(request_cache.NewSaveRequestEntryPoint(m))
	return nil
}
// makeMiddleware builds the SP login middleware from the components shared for f's SP options.
// If f has no success handler yet, it defaults to redirecting to the URI stored in the tracked
// request.
func (c *SamlAuthConfigurer) makeMiddleware(f *Feature, ws security.WebSecurity) *SPLoginMiddleware {
	opts := c.getServiceProviderConfiguration(f)
	sp := c.sharedServiceProvider(opts)
	clientManager := c.sharedClientManager(opts)
	tracker := c.sharedRequestTracker(opts)
	if f.successHandler == nil {
		f.successHandler = NewTrackedRequestSuccessHandler(tracker)
	}
	return NewLoginMiddleware(sp, tracker, c.idpManager, clientManager,
		c.effectiveSuccessHandler(f, ws),
		&Authenticator{accountStore: c.accountStore, idpManager: c.samlIdpManager},
		f.errorPath)
}
// newSamlAuthConfigurer creates the login configurer on top of the shared SAML configurer state.
func newSamlAuthConfigurer(shared *samlConfigurer, accountStore security.FederatedAccountStore) *SamlAuthConfigurer {
	configurer := SamlAuthConfigurer{samlConfigurer: shared}
	configurer.accountStore = accountStore
	return &configurer
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/crewjam/saml/samlsp"
"net/http"
)
// TrackedRequestSuccessHandler runs after a successful SAML login. It redirects the user to the
// URI stored in the tracked request named by the "RelayState" form value, or to "/" otherwise.
type TrackedRequestSuccessHandler struct {
tracker samlsp.RequestTracker
}
// NewTrackedRequestSuccessHandler returns a success handler that redirects using tracker.
func NewTrackedRequestSuccessHandler(tracker samlsp.RequestTracker) security.AuthenticationSuccessHandler {
	h := TrackedRequestSuccessHandler{tracker: tracker}
	return &h
}
// HandleAuthenticationSuccess redirects (302) to the URI of the tracked request named by the
// "RelayState" form value, and stops tracking that request. If there is no RelayState, or the
// tracked request cannot be read, it redirects to "/".
func (t *TrackedRequestSuccessHandler) HandleAuthenticationSuccess(c context.Context, r *http.Request, rw http.ResponseWriter, from, to security.Authentication) {
	target := "/"
	if index := r.Form.Get("RelayState"); index != "" {
		if tracked, err := t.tracker.GetTrackedRequest(r, index); err != nil {
			logger.WithContext(c).Errorf("error getting tracked request %v", err)
		} else {
			target = tracked.URI
		}
		_ = t.tracker.StopTrackingRequest(rw, r, index)
	}
	http.Redirect(rw, r, target, http.StatusFound)
	// http.Redirect writes no body for non-GET/HEAD requests such as the ACS POST. This empty
	// write presumably forces the response out, so that the writer reports Written() and the
	// caller aborts the handler chain.
	_, _ = rw.Write([]byte{})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/idp"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
samlutils "github.com/cisco-open/go-lanai/pkg/security/saml/utils"
netutil "github.com/cisco-open/go-lanai/pkg/utils/net"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/gin-gonic/gin"
"net/http"
)
// SPLoginMiddleware handles SAML SP login. It sends authentication requests to the IdP that
// matches the request's host (MakeAuthenticationRequest / Commence), and consumes the IdP's
// responses at the ACS endpoint (ACSHandlerFunc).
//
// A SAML service provider should be able to work with multiple identity providers.
// The saml package assumes a service provider is configured with only one idp. So the internal
// field stores this service provider's own settings, and a separate saml.ServiceProvider is
// created for each idp (by the client manager) when it's needed.
type SPLoginMiddleware struct {
SPMetadataMiddleware
// requestTracker records outgoing authentication requests, so responses can be matched to them.
requestTracker samlsp.RequestTracker
// authenticator turns a validated assertion into a security.Authentication.
authenticator security.Authenticator
// successHandler runs after a successful authentication.
successHandler security.AuthenticationSuccessHandler
// fallbackEntryPoint is used when no authentication request can be made (it redirects to the error path).
fallbackEntryPoint security.AuthenticationEntryPoint
}
// NewLoginMiddleware assembles an SPLoginMiddleware. sp is the SP template, and errorPath is
// the redirect target used when a SAML authentication request cannot be made.
func NewLoginMiddleware(sp saml.ServiceProvider, tracker samlsp.RequestTracker,
	idpManager idp.IdentityProviderManager,
	clientManager *CacheableIdpClientManager,
	handler security.AuthenticationSuccessHandler, authenticator security.Authenticator,
	errorPath string) *SPLoginMiddleware {
	mw := &SPLoginMiddleware{
		requestTracker:     tracker,
		authenticator:      authenticator,
		successHandler:     handler,
		fallbackEntryPoint: redirect.NewRedirectWithURL(errorPath),
	}
	mw.SPMetadataMiddleware = SPMetadataMiddleware{
		internal:      sp,
		idpManager:    idpManager,
		clientManager: clientManager,
	}
	return mw
}
// MakeAuthenticationRequest sends a SAML AuthnRequest to the IdP configured for the request's
// (forwarded) host name. Each domain may have a different IdP, so the request names the
// matching ACS for the IdP to call back. The request is tracked, and its tracking index is
// used as the RelayState. Only HTTP-POST is requested as the response binding.
//
// It returns an external SAML authentication error in these cases: no IdP matches the host,
// the IdP has no supported SSO binding, the request cannot be built, tracked or sent, or the
// resolved binding is not HTTP-Redirect or HTTP-POST.
func (sp *SPLoginMiddleware) MakeAuthenticationRequest(ctx context.Context, r *http.Request, w http.ResponseWriter) error {
	host := netutil.GetForwardedHostName(r)
	client, ok := sp.clientManager.GetClientByDomain(host)
	if !ok {
		logger.WithContext(ctx).Debugf("cannot find idp for domain %s", host)
		return security.NewExternalSamlAuthenticationError("cannot find idp for this domain")
	}
	location, binding := sp.resolveBinding(client.GetSSOBindingLocation)
	if location == "" {
		return security.NewExternalSamlAuthenticationError("idp does not have supported bindings.")
	}
	// Note: we only support post for result binding
	authReq, err := samlutils.NewFixedAuthenticationRequest(client, location, binding, saml.HTTPPostBinding)
	if err != nil {
		return security.NewExternalSamlAuthenticationError("cannot make auth request to binding location", err)
	}
	relayState, err := sp.requestTracker.TrackRequest(w, r, authReq.ID)
	if err != nil {
		return security.NewExternalSamlAuthenticationError("cannot track saml auth request", err)
	}
	switch binding {
	case saml.HTTPRedirectBinding:
		if e := sp.redirectBindingExecutor(authReq, relayState, client)(w, r); e != nil {
			return security.NewExternalSamlAuthenticationError("cannot make auth request with HTTP redirect binding", e)
		}
	case saml.HTTPPostBinding:
		if e := sp.postBindingExecutor(authReq, relayState)(w, r); e != nil {
			return security.NewExternalSamlAuthenticationError("cannot post auth request", e)
		}
	default:
		// NOTE(review): presumably resolveBinding only returns the two bindings above. If it
		// ever returns another one, fail so the fallback entry point runs, instead of
		// reporting success with nothing written to the response.
		return security.NewExternalSamlAuthenticationError(fmt.Errorf("unsupported binding [%s]", binding))
	}
	return nil
}
// ACSHandlerFunc returns the Assertion Consumer Service (ACS) endpoint handler. The IdP posts
// its authentication response here. The handler:
//  1. parses the SAML response and accepts only the HTTP-POST binding;
//  2. finds the IdP client by the response Issuer;
//  3. validates the response against the IDs of all tracked requests, plus "" when
//     IdP-initiated login is allowed;
//  4. authenticates the resulting assertion and passes the result to handleSuccess.
//
// Every failure goes through handleError, which also stops tracking the RelayState request.
func (sp *SPLoginMiddleware) ACSHandlerFunc() gin.HandlerFunc {
	return func(c *gin.Context) {
		resp := saml.Response{}
		switch rs := samlutils.ParseSAMLObject(c, &resp); {
		case rs.Err != nil:
			sp.handleError(c, security.NewExternalSamlAuthenticationError(fmt.Errorf("cannot process ACS request: %v", rs.Err)))
			return
		case rs.Binding != saml.HTTPPostBinding:
			sp.handleError(c, security.NewExternalSamlAuthenticationError(fmt.Errorf("unsupported binding [%s]", rs.Binding)))
			return
		}
		// Issuer is optional on a SAML Response and may be nil after parsing. Check it
		// instead of panicking on the dereference below.
		if resp.Issuer == nil {
			sp.handleError(c, security.NewExternalSamlAuthenticationError("SAML response does not specify an issuer"))
			return
		}
		r := c.Request
		client, ok := sp.clientManager.GetClientByEntityId(resp.Issuer.Value)
		if !ok {
			sp.handleError(c, security.NewExternalSamlAuthenticationError("cannot find idp metadata corresponding for assertion"))
			return
		}
		var possibleRequestIDs []string
		if sp.internal.AllowIDPInitiated {
			// An IdP-initiated response has no InResponseTo, so the empty ID is also accepted.
			possibleRequestIDs = append(possibleRequestIDs, "")
		}
		for _, tr := range sp.requestTracker.GetTrackedRequests(r) {
			possibleRequestIDs = append(possibleRequestIDs, tr.SAMLRequestID)
		}
		assertion, err := client.ParseResponse(r, possibleRequestIDs)
		if err != nil {
			logger.WithContext(c).Error("error processing assertion", "err", err)
			sp.handleError(c, security.NewExternalSamlAuthenticationError(err.Error(), err))
			return
		}
		auth, err := sp.authenticator.Authenticate(c, &AssertionCandidate{Assertion: assertion})
		if err != nil {
			sp.handleError(c, security.NewExternalSamlAuthenticationError(err))
			return
		}
		before := security.Get(c)
		sp.handleSuccess(c, before, auth)
	}
}
// Commence implements security.AuthenticationEntryPoint. It starts SAML login and hands over
// to the fallback entry point if no authentication request could be made. The triggering
// error is ignored.
func (sp *SPLoginMiddleware) Commence(c context.Context, r *http.Request, w http.ResponseWriter, _ error) {
	if err := sp.MakeAuthenticationRequest(c, r, w); err != nil {
		sp.fallbackEntryPoint.Commence(c, r, w, err)
	}
}
// handleSuccess stores the new authentication (when non-nil) in the security context and
// invokes the success handler. The gin chain is aborted only if a response has been written.
func (sp *SPLoginMiddleware) handleSuccess(c *gin.Context, before, after security.Authentication) {
	if after != nil {
		security.MustSet(c, after)
	}
	sp.successHandler.HandleAuthenticationSuccess(c, c.Request, c.Writer, before, after)
	if !c.Writer.Written() {
		return
	}
	c.Abort()
}
// handleError stops tracking the request named by RelayState (if any), clears the security
// context, records err on the gin context and aborts the chain.
func (sp *SPLoginMiddleware) handleError(c *gin.Context, err error) {
	req := c.Request
	if index := req.Form.Get("RelayState"); index != "" {
		_ = sp.requestTracker.StopTrackingRequest(c.Writer, req, index)
	}
	security.MustClear(c)
	_ = c.Error(err)
	c.Abort()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/csrf"
"github.com/cisco-open/go-lanai/pkg/security/logout"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
"github.com/cisco-open/go-lanai/pkg/security/request_cache"
"github.com/cisco-open/go-lanai/pkg/web/mapping"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
)
// SamlLogoutConfigurer applies the SAML logout Feature on top of the existing logout feature.
// It adds the SLO logout handler and entry point, and registers the SLO endpoints.
type SamlLogoutConfigurer struct {
*samlConfigurer
}
// Apply wires SAML single logout into ws:
//   - the SLO logout handler and a request-saving SLO entry point are added to the existing
//     logout feature;
//   - GET and POST endpoints at the SLO path receive the IdP's LogoutResponse/LogoutRequest,
//     together with the IdP metadata refresh middleware;
//   - CSRF protection is disabled for the SLO path, because these calls come from the IdP.
//
// feature must be a *Feature; any other type panics on the type assertion.
func (c *SamlLogoutConfigurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	f := feature.(*Feature)
	m := c.makeMiddleware(f, ws)
	lh := c.makeLogoutHandler(f, ws)

	logout.Configure(ws).
		AddLogoutHandler(lh).
		AddEntryPoint(request_cache.NewSaveRequestEntryPoint(m))

	// These endpoints are available whatever auth method is used, so no condition is applied.
	sloGet := mapping.Get(f.sloPath).
		HandlerFunc(m.LogoutHandlerFunc()).
		Name("saml slo as sp - get")
	sloPost := mapping.Post(f.sloPath).
		HandlerFunc(m.LogoutHandlerFunc()).
		Name("saml slo as sp - post")
	metadataRefresh := middleware.NewBuilder("saml idp metadata refresh").
		Order(security.MWOrderSAMLMetadataRefresh).
		Use(m.RefreshMetadataHandler())
	ws.Route(matcher.RouteWithPattern(f.sloPath)).
		Add(sloGet).
		Add(sloPost).
		Add(metadataRefresh)

	csrf.Configure(ws).
		IgnoreCsrfProtectionMatcher(matcher.RequestWithPattern(f.sloPath))
	return nil
}
// makeLogoutHandler returns the handler that triggers and reports SAML SLO during logout.
// It does not depend on the feature or on ws.
func (c *SamlLogoutConfigurer) makeLogoutHandler(_ *Feature, _ security.WebSecurity) *SingleLogoutHandler {
	handler := NewSingleLogoutHandler()
	return handler
}
// makeMiddleware builds the SP logout middleware from the components shared for f's SP options.
// If f has no success handler, it defaults to a saved-request handler that falls back to a
// redirect to "/", with a predicate that always returns true.
func (c *SamlLogoutConfigurer) makeMiddleware(f *Feature, ws security.WebSecurity) *SPLogoutMiddleware {
	opts := c.getServiceProviderConfiguration(f)
	sp := c.sharedServiceProvider(opts)
	clientManager := c.sharedClientManager(opts)
	if f.successHandler == nil {
		always := func(_, _ security.Authentication) bool { return true }
		f.successHandler = request_cache.NewSavedRequestAuthenticationSuccessHandler(
			redirect.NewRedirectWithURL("/"), always)
	}
	return NewLogoutMiddleware(sp, c.idpManager, clientManager, c.effectiveSuccessHandler(f, ws))
}
// newSamlLogoutConfigurer creates the logout configurer on top of the shared SAML configurer state.
func newSamlLogoutConfigurer(shared *samlConfigurer) *SamlLogoutConfigurer {
	configurer := SamlLogoutConfigurer{samlConfigurer: shared}
	return &configurer
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"net/http"
)
// ErrSamlSloRequired is returned by SingleLogoutHandler.ShouldLogout to interrupt logout.
// SPLogoutMiddleware.Commence responds to it by starting SP-initiated SLO.
var ErrSamlSloRequired = security.NewAuthenticationError("SAML SLO required")
// SingleLogoutHandler plugs SAML single logout into the generic logout process. It asks for
// SLO on SAML authentications that have not completed it, and reports a warning when SLO failed.
type SingleLogoutHandler struct{}
// NewSingleLogoutHandler returns a stateless SingleLogoutHandler.
// (The builtin new is shadowed in this package, so the composite literal's address is taken.)
func NewSingleLogoutHandler() *SingleLogoutHandler {
	handler := SingleLogoutHandler{}
	return &handler
}
// ShouldLogout implements logout.ConditionalLogoutHandler. For a SAML authentication whose SLO
// has not completed yet, it returns ErrSamlSloRequired. That interrupts the logout process and
// triggers the authentication entry point, which starts SLO. Otherwise it returns nil.
func (h *SingleLogoutHandler) ShouldLogout(ctx context.Context, _ *http.Request, _ http.ResponseWriter, auth security.Authentication) error {
	if h.requiresSamlSLO(ctx, auth) {
		return ErrSamlSloRequired
	}
	return nil
}
// HandleLogout returns the "cisco.saml.logout.failed" warning error when SLO was attempted for
// a SAML authentication and failed. Otherwise it returns nil.
func (h *SingleLogoutHandler) HandleLogout(ctx context.Context, _ *http.Request, _ http.ResponseWriter, auth security.Authentication) error {
	if h.wasSLOFailed(ctx, auth) {
		return security.NewAuthenticationWarningError("cisco.saml.logout.failed")
	}
	return nil
}
// samlDetails returns auth's details map and whether auth is a SAML assertion authentication.
// For non-SAML authentications, the map is returned only when Details() is a map[string]interface{}.
func (h *SingleLogoutHandler) samlDetails(_ context.Context, auth security.Authentication) (map[string]interface{}, bool) {
	if samlAuth, isSaml := auth.(*samlAssertionAuthentication); isSaml {
		return samlAuth.DetailsMap, true
	}
	details, _ := auth.Details().(map[string]interface{})
	return details, false
}
// requiresSamlSLO reports whether auth is a SAML authentication whose recorded SLO state does
// not yet include any completion outcome (full, partial or failed).
func (h *SingleLogoutHandler) requiresSamlSLO(ctx context.Context, auth security.Authentication) bool {
	details, isSaml := h.samlDetails(ctx, auth)
	if !isSaml {
		return false
	}
	state, recorded := details[kDetailsSLOState].(SLOState)
	return !recorded || !state.Is(SLOCompleted)
}
// wasSLOFailed reports whether auth is a SAML authentication whose recorded SLO state has the
// SLOFailed bit set.
func (h *SingleLogoutHandler) wasSLOFailed(ctx context.Context, auth security.Authentication) bool {
	details, isSaml := h.samlDetails(ctx, auth)
	if !isSaml {
		return false
	}
	// check whether SLO failed
	state, recorded := details[kDetailsSLOState].(SLOState)
	return recorded && state.Is(SLOFailed)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"context"
"encoding/gob"
"errors"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/idp"
samlutils "github.com/cisco-open/go-lanai/pkg/security/saml/utils"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/crewjam/saml"
"github.com/gin-gonic/gin"
"net/http"
)
// SLOState is a bit set that tracks the progress of SAML single logout (SLO) for the current
// authentication. It is stored in the authentication's details map.
type SLOState int

const (
	// SLOInitiated is set once an SP-initiated logout request has been sent.
	SLOInitiated SLOState = 1 << iota
	// SLOCompletedFully is set when the IdP's logout response validates.
	SLOCompletedFully
	SLOCompletedPartially
	// SLOFailed is set when SLO could not be initiated or the response was rejected.
	SLOFailed
	// SLOCompleted matches any terminal outcome: full, partial or failed.
	SLOCompleted = SLOCompletedFully | SLOCompletedPartially | SLOFailed
)

// Is reports whether s has at least one bit of mask set. A zero mask is a special case:
// Is(0) reports whether s itself is zero.
func (s SLOState) Is(mask SLOState) bool {
	if mask == 0 {
		return s == 0
	}
	return s&mask != 0
}
const (
// kDetailsSLOState is the key under which the SLOState is stored in the authentication's details map.
kDetailsSLOState = "SP.SLOState"
)
// init registers SLOState with encoding/gob, so that a details map holding an SLOState value
// can be gob-encoded and decoded. NOTE(review): presumably needed when the authentication is
// persisted, e.g. in a session; verify.
func init() {
gob.Register(SLOState(0))
}
// SPLogoutMiddleware implements SP-initiated SAML single logout. It sends LogoutRequests to the
// IdP (MakeSingleLogoutRequest / Commence) and handles the IdP's replies at the SLO endpoint
// (LogoutHandlerFunc).
type SPLogoutMiddleware struct {
SPMetadataMiddleware
// successHandler runs after SLO succeeds and also after it fails, so logout always continues.
successHandler security.AuthenticationSuccessHandler
}
// NewLogoutMiddleware assembles an SPLogoutMiddleware from the SP template, the IdP managers
// and the handler that continues logout once SLO finishes.
func NewLogoutMiddleware(sp saml.ServiceProvider,
	idpManager idp.IdentityProviderManager,
	clientManager *CacheableIdpClientManager,
	successHandler security.AuthenticationSuccessHandler) *SPLogoutMiddleware {
	mw := &SPLogoutMiddleware{successHandler: successHandler}
	mw.SPMetadataMiddleware = SPMetadataMiddleware{
		internal:      sp,
		idpManager:    idpManager,
		clientManager: clientManager,
	}
	return mw
}
// MakeSingleLogoutRequest starts SLO at the IdP that issued the current SAML authentication.
// It sends a signed LogoutRequest for the current NameID, using the SLO binding chosen by
// resolveBinding.
//
// It returns an error in these cases: the IdP is unknown, no supported SLO binding exists,
// the request cannot be built or signed, sending it fails, or the resolved binding is neither
// HTTP-Redirect nor HTTP-POST.
func (m *SPLogoutMiddleware) MakeSingleLogoutRequest(ctx context.Context, r *http.Request, w http.ResponseWriter) error {
	// resolve SP client
	client, e := m.resolveIdpClient(ctx)
	if e != nil {
		return e
	}
	// resolve binding
	location, binding := m.resolveBinding(client.GetSLOBindingLocation)
	if location == "" {
		return security.NewExternalSamlAuthenticationError("idp does not have supported SLO bindings.")
	}
	// create and send SLO request.
	nameId, format := m.resolveNameId(ctx)
	// Note 1: as of crewjam/saml, MakeLogoutRequest doesn't handle Redirect properly, so it is wrapped with a temporary fix.
	// Note 2: the SLO spec doesn't require RelayState, so none is sent.
	sloReq, e := samlutils.NewFixedLogoutRequest(client, location, nameId)
	if e != nil {
		return security.NewExternalSamlAuthenticationError("cannot make SLO request to binding location", e)
	}
	// Changing the NameID format invalidates any signature made at creation, so drop it and re-sign.
	sloReq.NameID.Format = format
	sloReq.Signature = nil
	if e := client.SignLogoutRequest(&sloReq.LogoutRequest); e != nil {
		return security.NewExternalSamlAuthenticationError("cannot sign SLO request", e)
	}
	switch binding {
	case saml.HTTPRedirectBinding:
		if e := m.redirectBindingExecutor(sloReq, "", client)(w, r); e != nil {
			return security.NewExternalSamlAuthenticationError("cannot send SLO request with HTTP redirect binding", e)
		}
	case saml.HTTPPostBinding:
		if e := m.postBindingExecutor(sloReq, "")(w, r); e != nil {
			return security.NewExternalSamlAuthenticationError("cannot send SLO request with HTTP post binding", e)
		}
	default:
		// NOTE(review): presumably resolveBinding only returns the two bindings above. If it
		// ever returns another one, report failure instead of claiming SLO was initiated
		// while nothing was sent.
		return security.NewExternalSamlAuthenticationError("unsupported SLO binding [" + binding + "]")
	}
	return nil
}
// LogoutHandlerFunc returns the SLO endpoint handler for the LogoutResponse or LogoutRequest
// sent by the IdP. A LogoutResponse answers an SLO that we initiated, and our internal logout
// process continues from there. Exactly one of SAMLRequest/SAMLResponse must parse. If both or
// neither do, the call is treated as an error.
func (m *SPLogoutMiddleware) LogoutHandlerFunc() gin.HandlerFunc {
	return func(gc *gin.Context) {
		var req saml.LogoutRequest
		var resp saml.LogoutResponse
		reqR := samlutils.ParseSAMLObject(gc, &req)
		respR := samlutils.ParseSAMLObject(gc, &resp)
		isRequest, isResponse := reqR.Err == nil, respR.Err == nil
		if isRequest == isResponse {
			m.handleError(gc, security.NewExternalSamlAuthenticationError("Error reading SAMLRequest/SAMLResponse", reqR.Err, respR.Err))
			return
		}
		if isResponse {
			m.handleLogoutResponse(gc, &resp, respR.Binding, respR.Encoded)
			return
		}
		m.handleLogoutRequest(gc, &req, reqR.Binding, reqR.Encoded)
	}
}
// Commence implements security.AuthenticationEntryPoint for SP-initiated SLO. It acts only on
// ErrSamlSloRequired. It sends the logout request and marks the SLO state as initiated. If
// sending fails, it goes through handleError, which records the failure and lets logout continue.
func (m *SPLogoutMiddleware) Commence(ctx context.Context, r *http.Request, w http.ResponseWriter, err error) {
	if !errors.Is(err, ErrSamlSloRequired) {
		return
	}
	logger.WithContext(ctx).Infof("trying to start SAML SP-Initiated SLO")
	switch e := m.MakeSingleLogoutRequest(ctx, r, w); {
	case e != nil:
		m.handleError(ctx, e)
	default:
		markInitiated := func(s SLOState) SLOState { return s | SLOInitiated }
		updateSLOState(ctx, markInitiated)
	}
}
// handleLogoutResponse validates the IdP's LogoutResponse to our SP-initiated SLO request.
// The IdP client is found by the response Issuer. The encoded payload is validated as
// HTTP-Redirect when binding says so, and as an HTTP-POST form otherwise. Success goes to
// handleSuccess and any failure to handleError. Both let the logout process continue.
func (m *SPLogoutMiddleware) handleLogoutResponse(gc *gin.Context, resp *saml.LogoutResponse, binding, encoded string) {
	// Issuer is optional in the SAML schema and may be nil after parsing. Report it instead
	// of panicking on the dereference below.
	if resp.Issuer == nil {
		m.handleError(gc, security.NewExternalSamlAuthenticationError("logout response does not specify an issuer"))
		return
	}
	client, ok := m.clientManager.GetClientByEntityId(resp.Issuer.Value)
	if !ok {
		m.handleError(gc, security.NewExternalSamlAuthenticationError("cannot find idp metadata corresponding for logout response"))
		return
	}
	var e error
	switch binding {
	case saml.HTTPRedirectBinding:
		e = client.ValidateLogoutResponseRedirect(encoded)
	default:
		e = client.ValidateLogoutResponseForm(encoded)
	}
	if e != nil {
		m.handleError(gc, e)
		return
	}
	m.handleSuccess(gc)
}
// handleLogoutRequest is meant to process LogoutRequests sent by the IdP (IdP-initiated SLO).
// NOTE(review): currently a no-op. LogoutHandlerFunc parses the request, but nothing here
// validates it, no LogoutResponse is sent back, and the gin chain is not aborted.
func (m *SPLogoutMiddleware) handleLogoutRequest(gc *gin.Context, req *saml.LogoutRequest, binding, encoded string) {
// TODO Handle Logout Request for IDP-initiated SLO
}
// resolveIdpClient returns the client of the IdP that issued the current SAML authentication.
// It returns an external SAML authentication error if the current authentication is not a SAML
// assertion authentication, carries no assertion, or names an unknown issuer.
func (m *SPLogoutMiddleware) resolveIdpClient(ctx context.Context) (*saml.ServiceProvider, error) {
	auth := security.Get(ctx)
	// Check SamlAssertion for nil, as resolveNameId does, before dereferencing its Issuer.
	if samlAuth, ok := auth.(*samlAssertionAuthentication); ok && samlAuth.SamlAssertion != nil {
		if sp, found := m.clientManager.GetClientByEntityId(samlAuth.SamlAssertion.Issuer.Value); found {
			return sp, nil
		}
	}
	return nil, security.NewExternalSamlAuthenticationError("Unable to initiate SLO as SP: unknown SAML Issuer")
}
// resolveNameId returns the NameID value and format from the subject of the current SAML
// assertion. It returns empty strings if the authentication is not SAML, or if the assertion,
// its subject or its NameID is missing.
func (m *SPLogoutMiddleware) resolveNameId(ctx context.Context) (nameId, format string) {
	samlAuth, isSaml := security.Get(ctx).(*samlAssertionAuthentication)
	if !isSaml || samlAuth.SamlAssertion == nil {
		return "", ""
	}
	subject := samlAuth.SamlAssertion.Subject
	if subject == nil || subject.NameID == nil {
		return "", ""
	}
	return subject.NameID.Value, subject.NameID.Format
}
// handleSuccess marks the SLO as fully completed and hands control to the configured success handler,
// passing the current authentication as both old and new authentication.
// The gin handler chain is aborted only if the success handler wrote a response.
func (m *SPLogoutMiddleware) handleSuccess(ctx context.Context) {
	updateSLOState(ctx, func(state SLOState) SLOState { return state | SLOCompletedFully })
	gc := web.GinContext(ctx)
	current := security.Get(ctx)
	m.successHandler.HandleAuthenticationSuccess(ctx, gc.Request, gc.Writer, current, current)
	if !gc.Writer.Written() {
		return
	}
	gc.Abort()
}
// handleError records a failed SLO attempt without blocking the logout.
// The error is logged, the SLOFailed flag is added to the SLO state, and then, exactly as on success,
// the configured success handler is invoked. The gin chain is aborted if a response was written.
func (m *SPLogoutMiddleware) handleError(ctx context.Context, e error) {
	logger.WithContext(ctx).Infof("SAML Single Logout failed with error: %v", e)
	updateSLOState(ctx, func(state SLOState) SLOState { return state | SLOFailed })
	// an SLO failure never prevents the logout from continuing
	gc := web.GinContext(ctx)
	current := security.Get(ctx)
	m.successHandler.HandleAuthenticationSuccess(ctx, gc.Request, gc.Writer, current, current)
	if written := gc.Writer.Written(); written {
		gc.Abort()
	}
}
/***********************
Helper Funcs
***********************/
// currentAuthDetails returns the details of the current authentication when they are a
// map[string]interface{}, and nil for any other details type.
func currentAuthDetails(ctx context.Context) map[string]interface{} {
	if details, ok := security.Get(ctx).Details().(map[string]interface{}); ok {
		return details
	}
	return nil
}
// currentSLOState reads the SLO state stored under kDetailsSLOState in the current authentication details.
// The zero state is returned when there are no map details or no SLOState value.
func currentSLOState(ctx context.Context) SLOState {
	var state SLOState
	if details := currentAuthDetails(ctx); details != nil {
		state, _ = details[kDetailsSLOState].(SLOState)
	}
	return state
}
// updateSLOState replaces the SLO state stored in the current authentication details with
// updater(current state); a missing or non-SLOState value is passed to updater as the zero state.
// Nothing happens when the authentication details are not a map.
func updateSLOState(ctx context.Context, updater func(current SLOState) SLOState) {
	if details := currentAuthDetails(ctx); details != nil {
		previous, _ := details[kDetailsSLOState].(SLOState)
		details[kDetailsSLOState] = updater(previous)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"encoding/xml"
"github.com/cisco-open/go-lanai/pkg/security/idp"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
samlutils "github.com/cisco-open/go-lanai/pkg/security/saml/utils"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/crewjam/saml"
"github.com/gin-gonic/gin"
"net/http"
"net/url"
)
// SupportedBindings is the set of SAML bindings this SP can use to send requests to an IDP:
// HTTP-Redirect and HTTP-POST. Preferred bindings outside this set are skipped by resolveBinding.
var SupportedBindings = utils.NewStringSet(saml.HTTPRedirectBinding, saml.HTTPPostBinding)
// SPMetadataMiddleware serves this service provider's SAML metadata and keeps the IDP clients' metadata fresh.
// A SAML service provider should be able to work with multiple identity providers.
// Because the saml package assumes a service provider is configured with one idp only,
// we use the internal field to store information about this service provider,
// and we create a new saml.ServiceProvider struct for each idp connection when it's needed.
type SPMetadataMiddleware struct {
// using value instead of pointer here because we need to copy it when connecting to specific idps.
// the methods on saml.ServiceProvider are pointer receivers; golang implicitly takes the address
// of this value when calling them
internal saml.ServiceProvider
// idpManager provides the configured identity providers (SAML ones are selected in refreshMetadata)
idpManager idp.IdentityProviderManager
// clientManager holds the per-IDP saml.ServiceProvider clients; refreshed by refreshMetadata
clientManager *CacheableIdpClientManager
}
// MetadataHandlerFunc returns the endpoint handler that serves this SP's metadata as a downloadable XML file.
// The metadata is built from the internal SP template. Its AssertionConsumerServices and SingleLogoutServices
// are replaced by those of every configured IDP client, so only configured domains are advertised.
// Each client contributes its first ACS (re-indexed by the client's position) and all of its SLO endpoints.
func (m *SPMetadataMiddleware) MetadataHandlerFunc() gin.HandlerFunc {
	return func(c *gin.Context) {
		// refresh here because the refresh metadata middleware is conditional,
		// but the metadata endpoint is not
		m.refreshMetadata(c)

		descriptor := m.internal.Metadata()
		var mergedAcs []saml.IndexedEndpoint
		var mergedSlo []saml.Endpoint
		for i, delegate := range m.clientManager.GetAllClients() {
			delegateDescriptor := delegate.Metadata().SPSSODescriptors[0]
			// ACS
			delegateAcs := delegateDescriptor.AssertionConsumerServices[0]
			delegateAcs.Index = i
			mergedAcs = append(mergedAcs, delegateAcs)
			// SLO
			mergedSlo = append(mergedSlo, delegateDescriptor.SingleLogoutServices...)
		}
		descriptor.SPSSODescriptors[0].AssertionConsumerServices = mergedAcs
		descriptor.SPSSODescriptors[0].SingleLogoutServices = mergedSlo

		buf, e := xml.MarshalIndent(descriptor, "", " ")
		if e != nil {
			// don't serve a truncated/empty document as if it were valid metadata
			_ = c.AbortWithError(http.StatusInternalServerError, e)
			return
		}
		w := c.Writer
		w.Header().Set("Content-Type", "application/samlmetadata+xml")
		w.Header().Set("Content-Disposition", "attachment; filename=metadata.xml")
		_, _ = w.Write(buf)
	}
}
// RefreshMetadataHandler returns a middleware that refreshes the IDPs' metadata cache.
// It is meant to run whenever a SAML login/logout related endpoint is called.
func (m *SPMetadataMiddleware) RefreshMetadataHandler() gin.HandlerFunc {
	return func(c *gin.Context) {
		m.refreshMetadata(c)
	}
}
// refreshMetadata reloads the IDP client cache from the SAML identity providers currently known to idpManager.
// The cache is populated by this middleware rather than lazily on commence because, in a multi-instance
// microservice deployment, the auth request and the auth response can be handled by different instances.
func (m *SPMetadataMiddleware) refreshMetadata(c *gin.Context) {
	providers := m.idpManager.GetIdentityProvidersWithFlow(c.Request.Context(), idp.ExternalIdpSAML)
	var samlProviders []samlctx.SamlIdentityProvider
	for _, provider := range providers {
		samlProvider, isSaml := provider.(samlctx.SamlIdentityProvider)
		if !isSaml {
			continue
		}
		samlProviders = append(samlProviders, samlProvider)
	}
	m.clientManager.RefreshCache(c, samlProviders)
}
// resolveBinding picks the first binding that is both supported (SupportedBindings) and has a non-empty
// location according to extractor. Candidates are idpManager's PreferredBindings when it implements
// samlctx.SamlBindingManager, otherwise HTTP-Redirect followed by HTTP-POST.
// Both results are empty when no candidate qualifies.
func (m *SPMetadataMiddleware) resolveBinding(extractor func(string) string) (location, binding string) {
	candidates := []string{saml.HTTPRedirectBinding, saml.HTTPPostBinding}
	if bindingManager, ok := m.idpManager.(samlctx.SamlBindingManager); ok {
		candidates = bindingManager.PreferredBindings()
	}
	for _, candidate := range candidates {
		loc := extractor(candidate)
		if loc == "" || !SupportedBindings.Has(candidate) {
			continue
		}
		return loc, candidate
	}
	return "", ""
}
// bindableSamlRequest abstracts a SAML request that can be sent with either binding;
// implemented e.g. by saml.AuthnRequest and FixedLogoutRequest.
type bindableSamlRequest interface {
// Redirect builds the HTTP-Redirect binding URL for the request
Redirect(relayState string, sp *saml.ServiceProvider) (*url.URL, error)
// Post returns the HTML form used for HTTP-POST binding (see postBindingExecutor)
Post(relayState string) []byte
}
// redirectBindingExecutor returns a function that sends req to the IDP with HTTP-Redirect binding:
// it answers with a 302 to the URL built by req.Redirect. Errors from building the URL are returned as-is.
func (m *SPMetadataMiddleware) redirectBindingExecutor(req bindableSamlRequest, relayState string, sp *saml.ServiceProvider) func(w http.ResponseWriter, r *http.Request) error {
	return func(w http.ResponseWriter, r *http.Request) error {
		target, err := req.Redirect(relayState, sp)
		if err == nil {
			http.Redirect(w, r, target.String(), http.StatusFound)
			// explicit zero-length write; presumably so wrapping writers record the response as written — TODO confirm
			_, _ = w.Write(nil)
		}
		return err
	}
}
// postBindingExecutor returns a function that sends req to the IDP with HTTP-POST binding,
// by writing the request's auto-submit form wrapped in an HTML page.
func (m *SPMetadataMiddleware) postBindingExecutor(req bindableSamlRequest, relayState string) func(w http.ResponseWriter, r *http.Request) error {
	return func(w http.ResponseWriter, _ *http.Request) error {
		return samlutils.WritePostBindingHTML(req.Post(relayState), w)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"encoding/gob"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/idp"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/cisco-open/go-lanai/pkg/web"
"go.uber.org/fx"
)
// logger is this package's logger, named "SAML.Auth".
var logger = log.New("SAML.Auth")
// Module is the bootstrap module of the SAML authenticator. Its only option invokes register, which adds
// the SAML login and logout security features. Precedence is security.MinSecurityPrecedence + 30.
var Module = &bootstrap.Module{
Name: "saml authenticator",
Precedence: security.MinSecurityPrecedence + 30,
Options: []fx.Option{
fx.Invoke(register),
},
}
// init registers *samlAssertionAuthentication with encoding/gob so it can be gob-encoded
// (presumably when authentications are persisted, e.g. in sessions — TODO confirm).
func init() {
gob.Register((*samlAssertionAuthentication)(nil))
}
// Use registers Module with the bootstrap so the SAML authenticator is enabled in the application.
func Use() {
bootstrap.Register(Module)
}
// initDI collects the dependencies injected by fx into register.
type initDI struct {
fx.In
// SecRegistrar is optional; register does nothing when it is absent
SecRegistrar security.Registrar `optional:"true"`
SamlProperties samlctx.SamlProperties
// ServerProps is not used by register
ServerProps web.ServerProperties
IdpManager idp.IdentityProviderManager
AccountStore security.FederatedAccountStore
}
// register adds the SAML login (FeatureId) and logout (LogoutFeatureId) features to the security registrar.
// Both features share one SAML configurer. Nothing happens when no registrar is available.
// The registrar must implement security.FeatureRegistrar, otherwise the type assertion panics.
func register(di initDI) {
	if di.SecRegistrar == nil {
		return
	}
	shared := newSamlConfigurer(di.SamlProperties, di.IdpManager)
	loginConfigurer := newSamlAuthConfigurer(shared, di.AccountStore)
	featureRegistrar := di.SecRegistrar.(security.FeatureRegistrar)
	featureRegistrar.RegisterFeature(FeatureId, loginConfigurer)
	featureRegistrar.RegisterFeature(LogoutFeatureId, newSamlLogoutConfigurer(shared))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sp
import (
"encoding/base64"
"fmt"
"github.com/crewjam/saml/samlsp"
"github.com/google/uuid"
"net/http"
"strings"
"time"
)
// This implementation is similar to samlsp.CookieRequestTracker.
// However, it doesn't need a reference to a ServiceProvider instance,
// it sets the cookie's Secure and Path attributes explicitly,
// and it lets the cookie's domain be determined by the request itself.
// This is because our tracker needs to work with multiple ServiceProvider instances, each talking to a different IDP.
// CookieRequestTracker tracks requests by setting a uniquely named
// cookie for each request.
type CookieRequestTracker struct {
// NamePrefix is prepended to the request index to form each cookie name
NamePrefix string
// Codec encodes/decodes the samlsp.TrackedRequest stored as the cookie value
Codec samlsp.TrackedRequestCodec
// MaxAge is the cookie Max-Age (truncated to whole seconds)
MaxAge time.Duration
// SameSite, Secure and Path are applied both when setting and when clearing the cookie
SameSite http.SameSite
Secure bool
Path string
}
// TrackRequest starts tracking the SAML request with the given ID by storing it in an HttpOnly cookie
// named NamePrefix+index. The returned index (a base64url-encoded random UUID) should be used as the
// RelayState in the SAML request flow. Encoding errors are returned with an empty index.
func (t CookieRequestTracker) TrackRequest(w http.ResponseWriter, r *http.Request, samlRequestID string) (string, error) {
	index := base64.RawURLEncoding.EncodeToString([]byte(uuid.New().String()))
	encoded, err := t.Codec.Encode(samlsp.TrackedRequest{
		Index:         index,
		SAMLRequestID: samlRequestID,
		URI:           r.URL.String(),
	})
	if err != nil {
		return "", err
	}
	cookie := &http.Cookie{
		Name:     t.NamePrefix + index,
		Value:    encoded,
		MaxAge:   int(t.MaxAge.Seconds()),
		HttpOnly: true,
		SameSite: t.SameSite,
		Secure:   t.Secure,
		Path:     t.Path,
	}
	http.SetCookie(w, cookie)
	return index, nil
}
// StopTrackingRequest stops tracking the SAML request identified by index (a value previously returned
// by TrackRequest) by overwriting its cookie with an empty, already-expired one carrying the same attributes.
// The error from looking up the cookie in r is returned unchanged.
func (t CookieRequestTracker) StopTrackingRequest(w http.ResponseWriter, r *http.Request, index string) error {
	cookie, err := r.Cookie(t.NamePrefix + index)
	if err != nil {
		return err
	}
	cookie.Value, cookie.Path = "", t.Path
	cookie.Secure, cookie.HttpOnly, cookie.SameSite = t.Secure, true, t.SameSite
	cookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{}
	http.SetCookie(w, cookie)
	return nil
}
// GetTrackedRequests returns every pending tracked request found in r's cookies.
// Cookies are skipped when their name lacks NamePrefix, their value cannot be decoded,
// or the decoded index differs from the cookie-name suffix.
func (t CookieRequestTracker) GetTrackedRequests(r *http.Request) []samlsp.TrackedRequest {
	var pending []samlsp.TrackedRequest
	for _, c := range r.Cookies() {
		if !strings.HasPrefix(c.Name, t.NamePrefix) {
			continue
		}
		decoded, err := t.Codec.Decode(c.Value)
		if err != nil || decoded.Index != strings.TrimPrefix(c.Name, t.NamePrefix) {
			continue
		}
		pending = append(pending, *decoded)
	}
	return pending
}
// GetTrackedRequest returns the pending tracked request stored under index.
// It fails when the cookie is missing, cannot be decoded, or was issued for a different index.
func (t CookieRequestTracker) GetTrackedRequest(r *http.Request, index string) (*samlsp.TrackedRequest, error) {
	cookie, err := r.Cookie(t.NamePrefix + index)
	if err != nil {
		return nil, err
	}
	tracked, err := t.Codec.Decode(cookie.Value)
	switch {
	case err != nil:
		return nil, err
	case tracked.Index != index:
		return nil, fmt.Errorf("expected index %q, got %q", index, tracked.Index)
	default:
		return tracked, nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"bytes"
"compress/flate"
"encoding/base64"
"fmt"
"github.com/beevik/etree"
lanaisaml "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/crewjam/saml"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
)
// ActualSamlRequest is a SAML request extracted from a test HTTP response by SamlRequestMatcher.
type ActualSamlRequest struct {
// XMLDoc is the decoded request XML (inflated for redirect binding)
XMLDoc *etree.Document
// Location is the form action (POST) or the Location header without its query string (redirect)
Location string
// RelayState is the RelayState form field or query parameter, empty if absent
RelayState string
// SigAlg is the SigAlg query parameter (redirect) or the ds:SignatureMethod Algorithm (POST)
SigAlg string
// Signature is the Signature query parameter (redirect) or the ds:SignatureValue text (POST)
Signature string
}
// SamlRequestMatcher extracts a SAML request from a recorded test response and builds failure messages.
type SamlRequestMatcher struct {
// SamlProperties is not read by the methods shown here; NOTE(review): presumably used by matchers embedding this type — confirm
SamlProperties lanaisaml.SamlProperties
// Binding selects the extraction strategy (saml.HTTPPostBinding or saml.HTTPRedirectBinding)
Binding string
// Subject and ExpectedMsg are only used in error/failure messages
Subject string
ExpectedMsg string
}
// Extract parses the SAML request contained in actual according to the matcher's Binding.
// Bindings other than HTTP-POST and HTTP-Redirect are reported as an error.
func (a SamlRequestMatcher) Extract(actual interface{}) (*ActualSamlRequest, error) {
	if a.Binding == saml.HTTPPostBinding {
		return a.extractPost(actual)
	}
	if a.Binding == saml.HTTPRedirectBinding {
		return a.extractRedirect(actual)
	}
	return nil, fmt.Errorf("unable to verify %s with binding '%s'", a.Subject, a.Binding)
}
// extractPost parses an HTTP-POST binding SAML request from the HTML auto-submit form recorded in a
// *httptest.ResponseRecorder. The form action becomes Location; the base64 SAMLRequest field is decoded
// into XMLDoc; RelayState and the XML signature (algorithm and value) are filled in when present.
// The recorder's body is read without being consumed, so FailureMessage can still show it.
func (a SamlRequestMatcher) extractPost(actual interface{}) (*ActualSamlRequest, error) {
	w, ok := actual.(*httptest.ResponseRecorder)
	if !ok || w == nil {
		return nil, fmt.Errorf("unable to verify %s with binding '%s': expected *httptest.ResponseRecorder, got %T", a.Subject, a.Binding, actual)
	}
	html := etree.NewDocument()
	if e := html.ReadFromBytes(w.Body.Bytes()); e != nil {
		return nil, e
	}
	formElem := html.FindElement("//form[@action]")
	if formElem == nil {
		return nil, fmt.Errorf("form with action is not found in HTML")
	}
	reqElem := html.FindElement("//input[@name='SAMLRequest']")
	if reqElem == nil {
		return nil, fmt.Errorf("form doesn't contain 'SAMLRequest' value in HTML")
	}
	reqDecoded, e := base64.StdEncoding.DecodeString(reqElem.SelectAttrValue("value", ""))
	if e != nil {
		return nil, e
	}
	req := ActualSamlRequest{
		XMLDoc:   etree.NewDocument(),
		Location: formElem.SelectAttrValue("action", ""),
	}
	if e := req.XMLDoc.ReadFromBytes(reqDecoded); e != nil {
		return nil, e
	}
	if elem := html.FindElement("//input[@name='RelayState']"); elem != nil {
		req.RelayState = elem.SelectAttrValue("value", "")
	}
	if elem := req.XMLDoc.FindElement("//ds:SignatureMethod"); elem != nil {
		req.SigAlg = elem.SelectAttrValue("Algorithm", "")
	}
	if elem := req.XMLDoc.FindElement("//ds:SignatureValue"); elem != nil {
		req.Signature = elem.Text()
	}
	return &req, nil
}
// extractRedirect parses an HTTP-Redirect binding SAML request from a 3xx response, given either as a
// *httptest.ResponseRecorder or an *http.Response. The SAMLRequest query parameter of the Location header
// is base64-decoded and inflated into XMLDoc; RelayState, SigAlg and Signature come from the query.
// Location is the header value without its query string.
// Unsupported inputs and non-redirect responses yield an error instead of a panic.
func (a SamlRequestMatcher) extractRedirect(actual interface{}) (*ActualSamlRequest, error) {
	var resp *http.Response
	switch v := actual.(type) {
	case *httptest.ResponseRecorder:
		resp = v.Result()
	case *http.Response:
		resp = v
	}
	if resp == nil {
		return nil, fmt.Errorf("unable to verify %s with binding '%s': unsupported value %T", a.Subject, a.Binding, actual)
	}
	if resp.StatusCode < 300 || resp.StatusCode > 399 {
		return nil, fmt.Errorf("not redirect")
	}
	loc := resp.Header.Get("Location")
	locUrl, e := url.Parse(loc)
	if e != nil {
		return nil, e
	}
	// strip the query string; strings.Cut keeps the whole value when there is no '?'
	loc, _, _ = strings.Cut(loc, "?")
	query := locUrl.Query()
	// Note redirect request is compressed
	compressed, e := base64.StdEncoding.DecodeString(query.Get("SAMLRequest"))
	if e != nil {
		return nil, e
	}
	r := flate.NewReader(bytes.NewReader(compressed))
	defer func() { _ = r.Close() }()
	reqDecoded, e := io.ReadAll(r)
	if e != nil {
		return nil, e
	}
	req := ActualSamlRequest{
		XMLDoc:     etree.NewDocument(),
		Location:   loc,
		RelayState: query.Get("RelayState"),
		SigAlg:     query.Get("SigAlg"),
		Signature:  query.Get("Signature"),
	}
	if e := req.XMLDoc.ReadFromBytes(reqDecoded); e != nil {
		return nil, e
	}
	return &req, nil
}
// FailureMessage explains a failed match: what was expected and what the actual response contained.
func (a SamlRequestMatcher) FailureMessage(actual interface{}) (message string) {
	return fmt.Sprintf("Expected %s as %s. Actual: %s", a.Subject, a.ExpectedMsg, describeActualSamlResponse(actual))
}

// NegatedFailureMessage explains an unexpected match.
func (a SamlRequestMatcher) NegatedFailureMessage(actual interface{}) (message string) {
	return fmt.Sprintf("Expected %s not as %s. Actual: %s", a.Subject, a.ExpectedMsg, describeActualSamlResponse(actual))
}

// describeActualSamlResponse renders actual for failure messages. It accepts every value Extract accepts
// (*httptest.ResponseRecorder and *http.Response) and falls back to %v for anything else, never panicking.
func describeActualSamlResponse(actual interface{}) string {
	switch v := actual.(type) {
	case *httptest.ResponseRecorder:
		if v != nil && v.Body != nil {
			return v.Body.String()
		}
	case *http.Response:
		if v != nil {
			return fmt.Sprintf("HTTP %d, Location: %s", v.StatusCode, v.Header.Get("Location"))
		}
	}
	return fmt.Sprintf("%v", actual)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlutils
import (
"bytes"
"compress/flate"
"encoding/base64"
"github.com/beevik/etree"
"github.com/crewjam/saml"
"net/url"
"strings"
)
// redirectUrl builds the HTTP-Redirect binding URL for a SAML message. Adapted from saml.AuthnRequest.Redirect.
// As of crewjam/saml 0.4.8, crewjam/saml made an attempt at fixing saml.AuthnRequest.Redirect with a correct Signature:
// https://github.com/crewjam/saml/pull/339
// However, per the SAML 2.0 Bindings spec, signing should only apply to the query "SAMLRequest=value&RelayState=value&SigAlg=value",
// but crewjam/saml 0.4.8 signs the whole query string.
// See https://www.oasis-open.org/committees/download.php/35387/sstc-saml-bindings-errata-2.0-wd-05-diff.pdf
// TODO revisit this part later when newer crewjam/saml library become available
//
// rootEl is serialized, DEFLATE-compressed and base64-encoded into the SAMLRequest parameter of dest.
// RelayState is added when non-empty. When sp.SignatureMethod is set, SigAlg and Signature are added; the
// signature covers the URL-encoded parameters in spec order. Compression, dest parsing and signing errors
// are returned.
func redirectUrl(relayState string, sp *saml.ServiceProvider, rootEl *etree.Element, dest string) (*url.URL, error) {
	// SAMLRequest value: base64(deflate(xml))
	var buf bytes.Buffer
	b64Writer := base64.NewEncoder(base64.StdEncoding, &buf)
	deflater, e := flate.NewWriter(b64Writer, flate.BestCompression)
	if e != nil {
		return nil, e
	}
	doc := etree.NewDocument()
	doc.SetRoot(rootEl)
	if _, e := doc.WriteTo(deflater); e != nil {
		return nil, e
	}
	// closing flushes remaining compressed data and base64 padding; a failure would corrupt the value
	if e := deflater.Close(); e != nil {
		return nil, e
	}
	if e := b64Writer.Close(); e != nil {
		return nil, e
	}
	encoded := buf.String()

	rv, e := url.Parse(dest)
	if e != nil {
		return nil, e
	}
	query := rv.Query()
	// rawKVs keeps the to-be-signed parameters in spec order: SAMLRequest, RelayState, SigAlg
	rawKVs := make([]string, 0, 3)
	rawKVs = append(rawKVs, HttpParamSAMLRequest+"="+url.QueryEscape(encoded))
	query.Set(HttpParamSAMLRequest, encoded)
	if relayState != "" {
		rawKVs = append(rawKVs, HttpParamRelayState+"="+url.QueryEscape(relayState))
		query.Set(HttpParamRelayState, relayState)
	}
	if len(sp.SignatureMethod) > 0 {
		rawKVs = append(rawKVs, HttpParamSigAlg+"="+url.QueryEscape(sp.SignatureMethod))
		query.Set(HttpParamSigAlg, sp.SignatureMethod)
		signingContext, e := saml.GetSigningContext(sp)
		if e != nil {
			return nil, e
		}
		sig, e := signingContext.SignString(strings.Join(rawKVs, "&"))
		if e != nil {
			return nil, e
		}
		query.Set(HttpParamSignature, base64.StdEncoding.EncodeToString(sig))
	}
	rv.RawQuery = query.Encode()
	return rv, nil
}
/***********************
AuthnRequest
***********************/
// FixedAuthnRequest wraps saml.AuthnRequest and overrides Redirect so that HTTP-Redirect binding
// signatures are computed per the SAML 2.0 Bindings spec (see redirectUrl).
type FixedAuthnRequest struct {
saml.AuthnRequest
}
// NewFixedAuthenticationRequest creates an AuthnRequest via sp.MakeAuthenticationRequest and wraps a copy
// of it in a FixedAuthnRequest. Errors from sp are returned unchanged.
func NewFixedAuthenticationRequest(sp *saml.ServiceProvider, idpURL string, binding string, resultBinding string) (*FixedAuthnRequest, error) {
	authnReq, err := sp.MakeAuthenticationRequest(idpURL, binding, resultBinding)
	if err != nil {
		return nil, err
	}
	return &FixedAuthnRequest{AuthnRequest: *authnReq}, nil
}
// Redirect returns the HTTP-Redirect binding URL for this request (crewjam/saml 0.4.8 hotfix).
// It clears the request's Signature element before serializing, since with redirect binding the
// signature is carried in the query string instead (see redirectUrl).
func (req *FixedAuthnRequest) Redirect(relayState string, sp *saml.ServiceProvider) (*url.URL, error) {
	req.Signature = nil
	root := req.Element()
	return redirectUrl(relayState, sp, root, req.Destination)
}
/***********************
LogoutRequest
***********************/
// FixedLogoutRequest wraps saml.LogoutRequest and overrides Redirect so that HTTP-Redirect binding
// signatures are computed per the SAML 2.0 Bindings spec (see redirectUrl).
type FixedLogoutRequest struct {
saml.LogoutRequest
}
// NewFixedLogoutRequest creates a LogoutRequest for nameID via sp.MakeLogoutRequest and wraps a copy
// of it in a FixedLogoutRequest. Errors from sp are returned unchanged.
func NewFixedLogoutRequest(sp *saml.ServiceProvider, idpURL, nameID string) (*FixedLogoutRequest, error) {
	logoutReq, err := sp.MakeLogoutRequest(idpURL, nameID)
	if err != nil {
		return nil, err
	}
	return &FixedLogoutRequest{LogoutRequest: *logoutReq}, nil
}
// Redirect returns the HTTP-Redirect binding URL for this request (crewjam/saml 0.4.8 hotfix).
// It clears the request's Signature element before serializing, since with redirect binding the
// signature is carried in the query string instead (see redirectUrl).
func (req *FixedLogoutRequest) Redirect(relayState string, sp *saml.ServiceProvider) (*url.URL, error) {
	req.Signature = nil
	root := req.Element()
	return redirectUrl(relayState, sp, root, req.Destination)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlutils
import (
"context"
"github.com/crewjam/httperr"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"io"
"net/http"
"net/url"
"os"
"strings"
)
// MetadataResolverOptions is a functional option that customizes the MetadataResolverOption used by ResolveMetadata.
type MetadataResolverOptions func(opt *MetadataResolverOption)
// MetadataResolverOption holds ResolveMetadata settings.
type MetadataResolverOption struct {
// HttpClient is used to fetch HTTP/HTTPS metadata sources; ResolveMetadata defaults it to http.DefaultClient
HttpClient *http.Client
}
// WithHttpClient returns an option that makes ResolveMetadata fetch remote metadata with client
// instead of http.DefaultClient.
func WithHttpClient(client *http.Client) MetadataResolverOptions {
	setter := func(opt *MetadataResolverOption) {
		opt.HttpClient = client
	}
	return setter
}
// ResolveMetadata resolves IDP metadata from the given source, returning the parsed descriptor and raw bytes.
// Supported sources:
//   - text starting with "<" is parsed as inline XML
//   - an HTTP/HTTPS URL is fetched over the network with the configured http.Client (default http.DefaultClient)
//   - a FILE URL (file://) is loaded from the file system
//   - any other value (no scheme) is treated as a file path, possibly relative
func ResolveMetadata(ctx context.Context, metadataSource string, opts ...MetadataResolverOptions) (*saml.EntityDescriptor, []byte, error) {
	opt := MetadataResolverOption{HttpClient: http.DefaultClient}
	for _, apply := range opts {
		apply(&opt)
	}
	if strings.HasPrefix(metadataSource, "<") {
		return ParseMetadataFromXml(metadataSource)
	}
	metadataUrl, err := url.Parse(metadataSource)
	if err != nil {
		return nil, nil, err
	}
	switch metadataUrl.Scheme {
	case "file", "":
		return ParseMetadataFromFile(metadataUrl.Path)
	default:
		return FetchMetadata(ctx, opt.HttpClient, metadataUrl)
	}
}
// ParseMetadataFromXml parses metadata given as inline XML text.
// It returns the parsed descriptor, the XML as bytes, and any parse error.
func ParseMetadataFromXml(xml string) (*saml.EntityDescriptor, []byte, error) {
	raw := []byte(xml)
	descriptor, parseErr := samlsp.ParseMetadata(raw)
	return descriptor, raw, parseErr
}
// ParseMetadataFromFile reads metadata XML from the file at fileLocation and parses it.
// It returns the parsed descriptor, the raw file content, and any read or parse error.
func ParseMetadataFromFile(fileLocation string) (*saml.EntityDescriptor, []byte, error) {
	// os.ReadFile opens, reads and closes the file, so no file handle is leaked
	data, err := os.ReadFile(fileLocation)
	if err != nil {
		return nil, nil, err
	}
	metadata, err := samlsp.ParseMetadata(data)
	return metadata, data, err
}
// FetchMetadata downloads metadata from metadataURL with a GET request bound to ctx and parses it.
// A nil httpClient falls back to http.DefaultClient. HTTP status >= 400 is reported as an httperr.Response error.
// It returns the parsed descriptor, the raw body (possibly partial on read error) and any error.
func FetchMetadata(ctx context.Context, httpClient *http.Client, metadataURL *url.URL) (*saml.EntityDescriptor, []byte, error) {
	if httpClient == nil {
		httpClient = http.DefaultClient
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL.String(), nil)
	if err != nil {
		return nil, nil, err
	}
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, nil, err
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode >= 400 {
		return nil, nil, httperr.Response(*resp)
	}
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, data, err
	}
	metadata, err := samlsp.ParseMetadata(data)
	return metadata, data, err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlutils
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"fmt"
"github.com/beevik/etree"
"github.com/russellhaering/goxmldsig/etreeutils"
"golang.org/x/net/html"
"net/http"
"strings"
)
// FindChild returns the first direct child of parentEl whose tag is childTag and whose namespace prefix
// resolves to the namespace URI childNS. It returns (nil, nil) when no child matches, and an error when
// namespace resolution fails for a child with the right tag.
func FindChild(parentEl *etree.Element, childNS string, childTag string) (*etree.Element, error) {
	for _, candidate := range parentEl.ChildElements() {
		if candidate.Tag != childTag {
			continue
		}
		nsCtx, err := etreeutils.NSBuildParentContext(candidate)
		if err == nil {
			nsCtx, err = nsCtx.SubContext(candidate)
		}
		if err != nil {
			return nil, err
		}
		resolved, err := nsCtx.LookupPrefix(candidate.Space)
		if err != nil {
			return nil, fmt.Errorf("[%s]:%s cannot find prefix %s: %v", childNS, childTag, candidate.Space, err)
		}
		if resolved == childNS {
			return candidate, nil
		}
	}
	return nil, nil
}
// WritePostBindingHTML wraps the HTML of a request/response submitting form in a full HTML document and
// writes it to rw. A Content-Security-Policy header allows only the page's own inline scripts (by sha256
// hash, see scriptSrcHash). The error from writing the body is returned.
func WritePostBindingHTML(formHtml []byte, rw http.ResponseWriter) error {
	page := []byte(fmt.Sprintf(`<!DOCTYPE html><html><body>%s</body></html>`, formHtml))
	headers := rw.Header()
	headers.Add("Content-Type", "text/html")
	headers.Add("Content-Security-Policy", fmt.Sprintf("default-src; script-src %s; reflected-xss block; referrer no-referrer;", scriptSrcHash(page)))
	_, err := rw.Write(page)
	return err
}
// scriptSrcHash returns the CSP source expressions ('sha256-<base64>') for every inline <script> text in
// htmlBytes, separated by spaces; empty when there is no inline script. If the HTML cannot be parsed,
// 'unsafe-inline' is returned instead. See the CSP specs.
func scriptSrcHash(htmlBytes []byte) string {
	const fallback = `'unsafe-inline'`
	root, e := html.Parse(bytes.NewReader(htmlBytes))
	if e != nil {
		return fallback
	}
	isInlineScript := func(n *html.Node) bool {
		return n.Type == html.TextNode && n.Parent != nil && n.Parent.Data == "script"
	}
	var sb strings.Builder
	for i, node := range findAllHtmlNodes(root, isInlineScript) {
		if i > 0 {
			sb.WriteByte(' ')
		}
		digest := sha256.Sum256([]byte(node.Data))
		sb.WriteString("'sha256-")
		sb.WriteString(base64.StdEncoding.EncodeToString(digest[:]))
		sb.WriteString("'")
	}
	return sb.String()
}
// findAllHtmlNodes returns, in document (pre-order) order, every descendant of node (node itself excluded)
// for which matcher returns true. The result is nil when nothing matches.
func findAllHtmlNodes(node *html.Node, matcher func(*html.Node) bool) (found []*html.Node) {
	var walk func(parent *html.Node)
	walk = func(parent *html.Node) {
		for c := parent.FirstChild; c != nil; c = c.NextSibling {
			if matcher(c) {
				found = append(found, c)
			}
			walk(c)
		}
	}
	walk(node)
	return found
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlutils
import (
"bytes"
"compress/flate"
"encoding/base64"
"encoding/xml"
"fmt"
"github.com/crewjam/saml"
"github.com/gin-gonic/gin"
"io"
)
// ParsableSamlTypes constrains the SAML message types ParseSAMLObject can unmarshal.
type ParsableSamlTypes interface {
saml.LogoutRequest | saml.LogoutResponse | saml.AuthnRequest | saml.Response
}
// SAMLObjectParseResult reports how ParseSAMLObject found and decoded a SAML message.
type SAMLObjectParseResult struct {
// Binding is saml.HTTPRedirectBinding when the value came from the query string, saml.HTTPPostBinding otherwise
Binding string
// Encoded is the raw (base64) parameter value
Encoded string
// Decoded is the base64-decoded value, inflated when it was DEFLATE-compressed
Decoded []byte
// Err is the first error encountered (missing parameter, base64 or XML decoding)
Err error
}
// ParseSAMLObject extracts a SAML message of type T from the current HTTP request and unmarshals it into dest.
// Requests (LogoutRequest, AuthnRequest) are read from the HttpParamSAMLRequest parameter, responses from
// HttpParamSAMLResponse. The query string is checked first (HTTP-Redirect binding), then the POST form
// (HTTP-POST binding). The value is base64-decoded, inflated when it is DEFLATE-compressed, then XML-unmarshalled.
func ParseSAMLObject[T ParsableSamlTypes](gc *gin.Context, dest *T) (ret SAMLObjectParseResult) {
	param := HttpParamSAMLResponse
	switch interface{}(dest).(type) {
	case *saml.LogoutRequest, *saml.AuthnRequest:
		param = HttpParamSAMLRequest
	}

	encoded, _ := gc.GetQuery(param)
	binding := saml.HTTPRedirectBinding
	if len(encoded) == 0 {
		encoded = gc.PostForm(param)
		binding = saml.HTTPPostBinding
	}
	ret.Binding, ret.Encoded = binding, encoded
	if len(encoded) == 0 {
		ret.Err = fmt.Errorf("unable to find %s in http request", param)
		return ret
	}

	if ret.Decoded, ret.Err = base64.StdEncoding.DecodeString(encoded); ret.Err != nil {
		return ret
	}
	// best-effort inflate: redirect-binding payloads are deflated, POST-binding payloads are plain XML
	inflater := flate.NewReader(bytes.NewReader(ret.Decoded))
	if inflated, err := io.ReadAll(inflater); err == nil {
		ret.Decoded = inflated
	}
	ret.Err = xml.Unmarshal(ret.Decoded, dest)
	return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samlutils
import (
"crypto"
"crypto/rsa"
//nolint:gosec // weak cryptographic primitive, but we still need to support it
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"github.com/beevik/etree"
"github.com/crewjam/saml"
dsig "github.com/russellhaering/goxmldsig"
"github.com/russellhaering/goxmldsig/etreeutils"
"hash"
"net/http"
"strings"
)
// SignatureVerifyOptions is a functional option that populates the SignatureContext used by VerifySignature.
type SignatureVerifyOptions func(sc *SignatureContext)
// SignatureContext carries the inputs of VerifySignature.
type SignatureContext struct {
// Binding selects the protocol: saml.HTTPRedirectBinding uses query-string (DEFLATE encoding) signatures,
// any other value uses enveloped XMLDSig
Binding string
// Certs are the trusted certificates the signature is verified against
Certs []*x509.Certificate
// Request is required for redirect binding (signature and signed values are read from it)
Request *http.Request
// XMLData is the signed XML document, required for XMLDSig verification
XMLData []byte
}
// MetadataSignature returns the SignatureVerifyOptions for metadata validation: enveloped XMLDSig
// (no binding) over data, trusting the given certificates.
func MetadataSignature(data []byte, certs ...*x509.Certificate) SignatureVerifyOptions {
	return func(sc *SignatureContext) {
		sc.Binding, sc.Certs, sc.XMLData = "", certs, data
	}
}
// VerifySignature verifies the signature of a SAML Request/Response/Metadata.
// The protocol is chosen from the binding in the resulting SignatureContext:
//   - saml.HTTPRedirectBinding uses DEFLATE encoding signatures; SignatureContext.Request and SignatureContext.Certs are required
//   - saml.HTTPPostBinding uses enveloped XMLDSig; SignatureContext.XMLData is required
//   - any other Binding value also uses enveloped XMLDSig; SignatureContext.XMLData is required
func VerifySignature(opts ...SignatureVerifyOptions) error {
	var sc SignatureContext
	for _, apply := range opts {
		apply(&sc)
	}
	if sc.Binding == saml.HTTPRedirectBinding {
		return verifyDeflateEncSign(&sc)
	}
	return verifyXMLDSign(&sc)
}
// verifyXMLDSign validates an enveloped XMLDSig signature on the root element of sc.XMLData,
// typically used for PostBinding messages or Metadata, against the trusted sc.Certs.
// Returns an error for missing, unparsable or root-less XML, ErrorXMLNotSigned when the root element has
// no direct ds:Signature child, and a generic error when the signature does not validate.
func verifyXMLDSign(sc *SignatureContext) error {
	if len(sc.XMLData) == 0 {
		return errors.New("XML document is missing for signature verification")
	}
	doc := etree.NewDocument()
	if err := doc.ReadFromBytes(sc.XMLData); err != nil {
		return errors.New("error parsing XML document for signature verification")
	}
	el := doc.Root()
	// e.g. a document made only of comments parses fine but has no root; FindChild would panic on nil
	if el == nil {
		return errors.New("error parsing XML document for signature verification")
	}
	// only a Signature that is a direct child of the root (enveloped signature) is considered
	sigEl, e := FindChild(el, "http://www.w3.org/2000/09/xmldsig#", "Signature")
	if e != nil || sigEl == nil {
		return ErrorXMLNotSigned
	}
	certificateStore := dsig.MemoryX509CertificateStore{
		Roots: sc.Certs,
	}
	validationContext := dsig.NewDefaultValidationContext(&certificateStore)
	validationContext.IdAttribute = "ID"
	if saml.Clock != nil {
		validationContext.Clock = saml.Clock
	}
	//if there's signature but keyInfo is not X509, then we remove the key info element, and just use the
	//default public key to verify.
	//if keyinfo is x509, it'll be verified that it's a trusted key before being used to verify the signature
	//See the logic in validationContext.Validate
	if el.FindElement("./Signature/KeyInfo/X509Data/X509Certificate") == nil {
		if keyInfo := sigEl.FindElement("KeyInfo"); keyInfo != nil {
			sigEl.RemoveChild(keyInfo)
		}
	}
	ctx, e := etreeutils.NSBuildParentContext(el)
	if e != nil {
		return errors.New("error getting document context for signature check")
	}
	if ctx, e = ctx.SubContext(el); e != nil {
		return errors.New("error getting document sub context for signature check")
	}
	//makes a copy of the element
	if el, e = etreeutils.NSDetatch(ctx, el); e != nil {
		return errors.New("error getting document for signature check")
	}
	if _, e = validationContext.Validate(el); e != nil {
		return errors.New("invalid signature")
	}
	return nil
}
// verifyDeflateEncSign validate DEFLATE URL encoding signature, typically used for RedirectBinding of SAML Request
// https://www.oasis-open.org/committees/download.php/35387/sstc-saml-bindings-errata-2.0-wd-05-diff.pdf
func verifyDeflateEncSign(sc *SignatureContext) error {
// some sanity check
if sc.Request == nil {
return fmt.Errorf("HTTP Request is required for DEFLATE Encoding signature verification")
}
if enc := queryValue(sc, HttpParamSAMLEncoding); len(enc) != 0 && enc != SAMLEncodingDeflate {
return fmt.Errorf("unsupported SAML encoding [%s]", enc)
}
// find signature
alg := queryValue(sc, HttpParamSigAlg)
encodedSig := queryValue(sc, HttpParamSignature)
if len(alg) == 0 || len(encodedSig) == 0 {
return ErrorXMLNotSigned
}
sig, e := base64.StdEncoding.DecodeString(encodedSig)
if e != nil || len(sig) == 0 {
return fmt.Errorf("failed to decode signature")
}
// extract to-be-verified data
toVerify := toBeVerifiedQuery(sc)
// verify
var err error
for _, cert := range sc.Certs {
if err = rsaVerify([]byte(toVerify), sig, cert.PublicKey, alg); err == nil {
return nil
}
}
return err
}
// toBeVerifiedQuery rebuilds the octet string covered by a RedirectBinding signature from the request's raw query:
//
//	SAMLRequest=value&RelayState=value&SigAlg=value
//	SAMLResponse=value&RelayState=value&SigAlg=value
//
// Per the SAML spec, the original URL-encoded query has to be used instead of re-encoding decoded values, so
// matching "key=value" pairs are copied verbatim. Absent parameters are omitted. When a parameter occurs more
// than once, its last occurrence is used.
func toBeVerifiedQuery(sc *SignatureContext) string {
	var message, relayState, sigAlg string
	for _, pair := range strings.Split(sc.Request.URL.RawQuery, "&") {
		key, _, _ := strings.Cut(pair, "=")
		switch key {
		case HttpParamSAMLRequest, HttpParamSAMLResponse:
			message = pair
		case HttpParamRelayState:
			relayState = pair
		case HttpParamSigAlg:
			sigAlg = pair
		}
	}

	ordered := [...]string{message, relayState, sigAlg}
	present := make([]string, 0, len(ordered))
	for _, part := range ordered {
		if part != "" {
			present = append(present, part)
		}
	}
	return strings.Join(present, "&")
}
// rsaVerify verifies an RSA PKCS #1 v1.5 signature of data.
// method is an XMLDSig signature method URI. RSA-SHA1, RSA-SHA256 and RSA-SHA512 are supported.
// pubKey must be an *rsa.PublicKey.
// An unsupported method is reported before the key type is checked.
func rsaVerify(data, signature []byte, pubKey any, method string) error {
	var (
		digester hash.Hash
		algo     crypto.Hash
	)
	switch method {
	case dsig.RSASHA1SignatureMethod:
		//nolint:gosec // weak cryptographic primitive, but we still need to support it
		digester, algo = sha1.New(), crypto.SHA1
	case dsig.RSASHA256SignatureMethod:
		digester, algo = sha256.New(), crypto.SHA256
	case dsig.RSASHA512SignatureMethod:
		digester, algo = sha512.New(), crypto.SHA512
	default:
		return fmt.Errorf("unsupported signature method: %s", method)
	}
	// hash.Hash.Write never returns an error
	_, _ = digester.Write(data)
	digest := digester.Sum(nil)

	key, isRSA := pubKey.(*rsa.PublicKey)
	if !isRSA {
		return fmt.Errorf("RSA public key is required to verify signature")
	}
	return rsa.VerifyPKCS1v15(key, algo, digest, signature)
}
// queryValue returns the first value of the query parameter key in sc.Request's URL, or "" when it is absent.
// The raw query is parsed on every call.
func queryValue(sc *SignatureContext, key string) string {
	query := sc.Request.URL.Query()
	return query.Get(key)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package common
import (
"fmt"
"time"
)
// RedisNameSpace prefixes every session key stored in Redis.
// This is to avoid confusion with records from other frameworks.
const RedisNameSpace = "LANAI:SESSION"

// Names of session metadata attributes.
const (
	SessionLastAccessedField   = "lastAccessed"
	SessionIdleTimeoutDuration = "idleTimeout"
	SessionAbsTimeoutTime      = "absTimeout"
)

// DefaultName is the session name (and cookie name) used when none is configured.
const DefaultName = "SESSION"

// TimeoutSetting is a bit set of the session timeouts that are enabled.
type TimeoutSetting int

const (
	// IdleTimeoutEnabled indicates the session expires after a period of inactivity.
	IdleTimeoutEnabled TimeoutSetting = 1 << iota
	// AbsoluteTimeoutEnabled indicates the session expires a fixed time after its creation.
	AbsoluteTimeoutEnabled
)

// GetRedisSessionKey returns the Redis key of the session with the given name and id,
// formatted as "<RedisNameSpace>:<name>:<id>".
func GetRedisSessionKey(name string, id string) string {
	return fmt.Sprintf("%s:%s:%s", RedisNameSpace, name, id)
}

// CalculateExpiration determines when a session expires, given which timeouts are enabled.
//
// Returns:
//   - (true, idleExpiration) when only the idle timeout is enabled.
//   - (true, absExpiration) when only the absolute timeout is enabled.
//   - The earlier of the two when both are enabled (absExpiration on a tie).
//   - (false, zero time) for any other setting, meaning the session never expires.
func CalculateExpiration(setting TimeoutSetting, idleExpiration time.Time, absExpiration time.Time) (canExpire bool, expiration time.Time) {
	switch setting {
	case IdleTimeoutEnabled:
		return true, idleExpiration
	case AbsoluteTimeoutEnabled:
		return true, absExpiration
	case IdleTimeoutEnabled | AbsoluteTimeoutEnabled:
		// whichever is the earliest
		if idleExpiration.Before(absExpiration) {
			return true, idleExpiration
		}
		return true, absExpiration
	}
	return false, time.Time{}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package session
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/session/common"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
)
var (
	// FeatureId identifies the session feature. Its order is security.FeatureOrderSession.
	FeatureId = security.FeatureId("Session", security.FeatureOrderSession)
)

// Feature holds session configuration
type Feature struct {
	sessionName    string
	settingService SettingService
}

// Identifier returns FeatureId.
func (f *Feature) Identifier() security.FeatureIdentifier {
	return FeatureId
}

// SettingService overrides the SettingService used to look up the maximum number of concurrent sessions.
// When it is not set, a default service backed by security.SessionProperties is used.
// Returns the same Feature for chaining.
func (f *Feature) SettingService(svc SettingService) *Feature {
	f.settingService = svc
	return f
}

// SessionName sets the session name, which is also the name of the session cookie.
// Returns the same Feature for chaining.
func (f *Feature) SessionName(name string) *Feature {
	f.sessionName = name
	return f
}
// Configure is the standard security.Feature entrypoint. It enables a new session Feature on ws and returns it
// for further customization.
// It panics when ws doesn't implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *Feature {
	modifier, ok := ws.(security.FeatureModifier)
	if !ok {
		panic(fmt.Errorf("unable to configure session: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
	}
	return modifier.Enable(New()).(*Feature)
}

// New is the standard security.Feature entrypoint, DSL style. Used with security.WebSecurity.
// The returned Feature uses common.DefaultName as its session name.
func New() *Feature {
	return (&Feature{}).SessionName(common.DefaultName)
}
// Configurer applies the session Feature to a security.WebSecurity. It uses the given session Store and the
// application's security.SessionProperties.
type Configurer struct {
	store        Store
	sessionProps security.SessionProperties
}

// newSessionConfigurer creates a Configurer bound to the given properties and store.
func newSessionConfigurer(sessionProps security.SessionProperties, sessionStore Store) *Configurer {
	c := new(Configurer)
	c.store = sessionStore
	c.sessionProps = sessionProps
	return c
}
// Apply configures session support on ws for feature, which must be a *Feature. It:
//   - defaults the session name to common.DefaultName;
//   - shares the session store on ws, unless a store is already shared;
//   - adds the session and authentication-persistence middlewares;
//   - registers the change-session, concurrent-session and delete-on-logout success handlers, plus the debug
//     success/error handlers when bootstrap debug mode is enabled.
//
// It panics if the shared composite success/error handlers are missing or of an unexpected type.
func (sc *Configurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	f := feature.(*Feature)
	if len(f.sessionName) == 0 {
		f.sessionName = common.DefaultName
	}

	// Share the store on ws so other feature configurers can access it. An already shared store is kept.
	if ws.Shared(security.WSSharedKeySessionStore) == nil {
		_ = ws.AddShared(security.WSSharedKeySessionStore, sc.store)
	}

	// configure middleware
	manager := NewManager(f.sessionName, sc.store)
	ws.Add(
		middleware.NewBuilder("sessionMiddleware").
			Order(security.MWOrderSessionHandling).
			Use(manager.SessionHandlerFunc()),
		middleware.NewBuilder("authPersistMiddleware").
			Order(security.MWOrderAuthPersistence).
			Use(manager.AuthenticationPersistenceHandlerFunc()),
	)

	// configure auth success/error handlers
	successHandlers := ws.Shared(security.WSSharedKeyCompositeAuthSuccessHandler).(*security.CompositeAuthenticationSuccessHandler)
	successHandlers.Add(&ChangeSessionHandler{})
	if bootstrap.DebugEnabled() {
		successHandlers.Add(&DebugAuthSuccessHandler{})
		ws.Shared(security.WSSharedKeyCompositeAuthErrorHandler).(*security.CompositeAuthenticationErrorHandler).
			Add(&DebugAuthErrorHandler{})
	}

	// a SettingService set on the feature takes precedence over the properties-backed default
	settingService := f.settingService
	if settingService == nil {
		settingService = NewDefaultSettingService(sc.sessionProps)
	}
	successHandlers.Add(&ConcurrentSessionHandler{
		sessionStore:          sc.store,
		sessionSettingService: settingService,
	})
	successHandlers.Add(&DeleteSessionOnLogoutHandler{
		sessionStore: sc.store,
	})
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package session
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"net/http"
)
// DebugAuthSuccessHandler logs authentication successes at debug level.
// The session configurer only registers it when bootstrap debug mode is enabled.
type DebugAuthSuccessHandler struct{}

// HandleAuthenticationSuccess logs the authentication transition from `from` to `to`.
func (h *DebugAuthSuccessHandler) HandleAuthenticationSuccess(
	_ context.Context, _ *http.Request, _ http.ResponseWriter, from, to security.Authentication,
) {
	logger.Debugf("session knows auth succeeded: from [%v] to [%v]", from, to)
}

// DebugAuthErrorHandler logs authentication failures at debug level.
// The session configurer only registers it when bootstrap debug mode is enabled.
type DebugAuthErrorHandler struct{}

// HandleAuthenticationError logs the authentication error. err is expected to be non-nil.
func (h *DebugAuthErrorHandler) HandleAuthenticationError(_ context.Context, _ *http.Request, _ http.ResponseWriter, err error) {
	msg := err.Error()
	logger.Debugf("session knows auth failed with %v", msg)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package session
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"net/http"
"sort"
)
// ChangeSessionHandler assigns a new ID to an existing session when the user becomes authenticated, and sends
// the updated session cookie.
//
// This is a high priority handler because it writes to the header. Therefore, it must run before any other
// success handler that may write the response status (e.g. redirect handler).
type ChangeSessionHandler struct{}

// HandleAuthenticationSuccess changes the current session's ID on authentication.
// It does nothing when the user is not becoming authenticated, when there's no session, or when the session is
// new (not yet saved).
// It panics with a security internal error when the ID cannot be changed.
func (h *ChangeSessionHandler) HandleAuthenticationSuccess(c context.Context, r *http.Request, rw http.ResponseWriter, from, to security.Authentication) {
	if !security.IsBeingAuthenticated(from, to) {
		return
	}
	// a new session that hasn't been saved doesn't need to be changed
	s := Get(c)
	if s == nil || s.isNew {
		return
	}
	if err := s.ChangeId(); err != nil {
		panic(security.NewInternalError("Failed to update session ID", err))
	}
	http.SetCookie(rw, NewCookie(s.Name(), s.id, s.options, r))
}

// PriorityOrder returns security.HandlerOrderChangeSession.
func (h *ChangeSessionHandler) PriorityOrder() int {
	return security.HandlerOrderChangeSession
}
// ConcurrentSessionHandler limits the number of concurrent sessions of a principal.
// This handler runs after ChangeSessionHandler, so that the updated session id is indexed to the principal.
type ConcurrentSessionHandler struct {
	sessionStore          Store
	sessionSettingService SettingService
}

// HandleAuthenticationSuccess indexes the current session under the authenticated username.
// When the maximum from SettingService.GetMaximumSessions is positive and exceeded, it invalidates the
// oldest sessions (by creation time) until the limit is met.
// Any store failure is raised as a panic carrying a security internal error.
func (h *ConcurrentSessionHandler) HandleAuthenticationSuccess(c context.Context, _ *http.Request, _ http.ResponseWriter, from, to security.Authentication) {
	if !security.IsBeingAuthenticated(from, to) {
		return
	}
	s := Get(c)
	if s == nil {
		return
	}
	p, err := security.GetUsername(to)
	if err != nil {
		//Auth is something we don't recognize, this indicates a program error
		panic(security.NewInternalError(err.Error()))
	}

	store := h.sessionStore.WithContext(c)
	//Adding to the index before checking the limit.
	//If done other way around, concurrent logins may be doing the check before the other request added to the index
	//thus making it possible to exceed the limit
	//By doing the check at the end, we can end up with the right number of sessions when all requests finishes.
	if err = store.AddToPrincipalIndex(p, s); err != nil {
		panic(security.NewInternalError(err.Error()))
	}

	//This will also clean the expired sessions from the index, so we do it regardless if max sessions is set or not
	existing, err := store.FindByPrincipalName(p, s.Name())
	if err != nil {
		panic(security.NewInternalError(err.Error()))
	}

	maxSessions := h.sessionSettingService.GetMaximumSessions(c)
	if maxSessions <= 0 || len(existing) <= maxSessions {
		return
	}
	// oldest first, so the excess sessions are at the head of the slice
	// NOTE(review): the current session keeps its original creation time across ChangeId, so a long-lived
	// pre-login session could itself be among the invalidated ones — confirm this is intended.
	sort.SliceStable(existing, func(i, j int) bool {
		return existing[i].createdOn().Before(existing[j].createdOn())
	})
	if e := store.Invalidate(existing[:len(existing)-maxSessions]...); e != nil {
		// keep the store error as cause, instead of silently dropping it
		panic(security.NewInternalError("Cannot delete session that exceeded max concurrent session limit", e))
	}
}

// PriorityOrder returns security.HandlerOrderConcurrentSession.
func (h *ConcurrentSessionHandler) PriorityOrder() int {
	return security.HandlerOrderConcurrentSession
}
// DeleteSessionOnLogoutHandler invalidates the current session when the user is being un-authenticated (logout).
type DeleteSessionOnLogoutHandler struct {
	sessionStore Store
}

// HandleAuthenticationSuccess invalidates the session found in c on logout, and always clears the session from c
// afterwards.
// It panics with the store's error when invalidation fails.
func (h *DeleteSessionOnLogoutHandler) HandleAuthenticationSuccess(c context.Context, _ *http.Request, _ http.ResponseWriter, from, to security.Authentication) {
	if !security.IsBeingUnAuthenticated(from, to) {
		return
	}
	s := Get(c)
	// the session is removed from the context regardless of the outcome below
	defer MustSet(c, nil)
	if s == nil {
		return
	}
	// bind the store to the request context, consistent with the other session handlers
	if e := h.sessionStore.WithContext(c).Invalidate(s); e != nil {
		panic(e)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package session
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/gin-gonic/gin"
"net/http"
)
const (
	// sessionKeySecurity is the session attribute under which the current security.Authentication is persisted.
	sessionKeySecurity = "Security"
)

// sessionCtxKey is the private context key under which the current *Session is stored.
type sessionCtxKey struct{}

// Get returns Session stored in given context. May return nil
func Get(c context.Context) *Session {
	if s, ok := c.Value(sessionCtxKey{}).(*Session); ok {
		return s
	}
	return nil
}

// MustSet is the panicking version of Set
func MustSet(c context.Context, s *Session) {
	err := Set(c, s)
	if err != nil {
		panic(err)
	}
}

// Set stores the given Session into the given context.
// It returns a security internal error when the context is not backed by a utils.MutableContext.
func Set(c context.Context, s *Session) error {
	if mc := utils.FindMutableContext(c); mc != nil {
		mc.Set(sessionCtxKey{}, s)
		return nil
	}
	return security.NewInternalError(fmt.Sprintf(`unable to set session into context: given context [%T] is not mutable`, c))
}
// Manager provides the gin middlewares that load, persist and save sessions with the given name in the given
// store.
type Manager struct {
	name  string
	store Store
}

// NewManager creates a Manager for sessions named sessionName (also the cookie name) kept in store.
func NewManager(sessionName string, store Store) *Manager {
	return &Manager{store: store, name: sessionName}
}
// SessionHandlerFunc provides the middleware for basic session management. It:
//   - looks up the session whose id is in the request cookie named after the manager (the store may return a
//     new session);
//   - sends the session cookie when the session is new;
//   - stores the session into the request context;
//   - runs the rest of the handler chain, then saves the session.
//
// The request is aborted with 500 when the store fails, or when the session cannot be set into the context.
func (m *Manager) SessionHandlerFunc() gin.HandlerFunc {
	return func(c *gin.Context) {
		// defer is FILO: the rest of the chain (c.Next) runs when this body returns, then the session is saved.
		defer m.saveSession(c)
		defer c.Next()

		var id string
		if cookie, err := c.Request.Cookie(m.name); err == nil {
			id = cookie.Value
		}
		session, err := m.store.WithContext(c).Get(id, m.name)
		// If session store is not operating properly, we cannot continue for misc that needs session
		if err != nil {
			_ = c.AbortWithError(http.StatusInternalServerError, err)
			return
		}
		if session != nil && session.isNew {
			logger.WithContext(c).Debugf("New Session %s", session.id)
			http.SetCookie(c.Writer, NewCookie(session.Name(), session.id, session.options, c.Request))
		}
		if e := Set(c, session); e != nil {
			// Report the Set failure itself. `err` is always nil here, and gin panics on a nil error.
			_ = c.AbortWithError(http.StatusInternalServerError, e)
			return
		}
	}
}
// AuthenticationPersistenceHandlerFunc provides the middleware that loads the security.Authentication from the
// current session into the context. After the rest of the chain has run, it writes the context's authentication
// back into the session.
// Without a session in the context, nothing is loaded or persisted.
func (m *Manager) AuthenticationPersistenceHandlerFunc() gin.HandlerFunc {
	return func(c *gin.Context) {
		// defer is FILO: c.Next runs first, then the authentication is persisted
		defer m.persistAuthentication(c)
		defer c.Next()

		current := Get(c)
		if current == nil {
			// no session found in current ctx, do nothing
			return
		}
		// a missing or non-Authentication value results in a nil authentication
		auth, _ := current.Get(sessionKeySecurity).(security.Authentication)
		security.MustSet(c, auth)
	}
}

// saveSession saves the session found in c, if any. It aborts the request with 500 when saving fails.
// NOTE(review): this runs after the handler chain, so the response may already be written at that point.
func (m *Manager) saveSession(c *gin.Context) {
	session := Get(c)
	if session == nil {
		return
	}
	if err := m.store.WithContext(c).Save(session); err != nil {
		_ = c.AbortWithError(http.StatusInternalServerError, err)
	}
}

// persistAuthentication copies the context's current security.Authentication into the session found in c, if any.
func (m *Manager) persistAuthentication(c *gin.Context) {
	if session := Get(c); session != nil {
		session.Set(sessionKeySecurity, security.Get(c))
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package session
import (
"encoding/gob"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/passwd"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/template"
"go.uber.org/fx"
"path"
"time"
)
// logger is the package logger, named "SEC.Session".
var logger = log.New("SEC.Session")

// Module is the bootstrap module of the session feature. It:
//   - binds security.SessionProperties;
//   - provides the session Store (nil when security or the redis client factory is unavailable);
//   - registers the session feature configurer.
var Module = &bootstrap.Module{
	Name:       "session",
	Precedence: security.MinSecurityPrecedence + 10,
	Options: []fx.Option{
		fx.Provide(security.BindSessionProperties),
		fx.Provide(provideSessionStore),
		fx.Invoke(register),
	},
}

// init registers Module with bootstrap and registers the session, security and passwd types with gob.
// Session values may be gob-encoded by the store; see the comment on Session.values.
// It also exposes the current Session to templates under template.ModelKeySession.
func init() {
	bootstrap.Register(Module)
	GobRegister()
	security.GobRegister()
	passwd.GobRegister()
	template.RegisterGlobalModelValuer(template.ModelKeySession, template.ContextModelValuer(Get))
}

// GobRegister registers with gob the concrete types this package stores in session values.
// []interface{} is the type used for flash message lists (see Session.AddFlash).
func GobRegister() {
	gob.Register([]interface{}{})
}
// storeDI holds the dependencies of provideSessionStore.
type storeDI struct {
	fx.In
	AppContext    *bootstrap.ApplicationContext
	SecRegistrar  security.Registrar `optional:"true"`
	SessionProps  security.SessionProperties
	ServerProps   web.ServerProperties        `optional:"true"`
	ClientFactory redis.ClientFactory         `optional:"true"`
	SettingReader security.GlobalSettingReader `optional:"true"`
}

// provideSessionStore creates a Redis backed session Store, configured from security.SessionProperties.
// It returns nil when security is not enabled (no registrar) or no redis ClientFactory is available.
// It panics when the redis client cannot be created.
func provideSessionStore(di storeDI) Store {
	if di.SecRegistrar == nil || di.ClientFactory == nil {
		return nil
	}

	client, err := di.ClientFactory.New(di.AppContext, func(opt *redis.ClientOption) {
		opt.DbIndex = di.SessionProps.DbIndex
	})
	if err != nil {
		panic(err)
	}

	return NewRedisStore(client, func(opt *StoreOption) {
		cookie := di.SessionProps.Cookie
		opt.SettingReader = di.SettingReader
		// NOTE(review): path.Clean("") returns ".", confirm the cookie path property is never empty.
		opt.Options.Path = path.Clean(cookie.Path)
		opt.Options.Domain = cookie.Domain
		opt.Options.MaxAge = cookie.MaxAge
		opt.Options.Secure = cookie.Secure
		opt.Options.HttpOnly = cookie.HttpOnly
		opt.Options.SameSite = cookie.SameSite()
		opt.Options.IdleTimeout = time.Duration(di.SessionProps.IdleTimeout)
		opt.Options.AbsoluteTimeout = time.Duration(di.SessionProps.AbsoluteTimeout)
	})
}
// initDI holds the dependencies of register.
type initDI struct {
	fx.In
	AppContext            *bootstrap.ApplicationContext
	SecRegistrar          security.Registrar `optional:"true"`
	SessionProps          security.SessionProperties
	SessionStore          Store          `optional:"true"`
	SessionSettingService SettingService `optional:"true"`
}

// register registers the session feature configurer with the security registrar.
// It does nothing when security is not enabled or no session store is available.
// NOTE(review): SessionSettingService is injected but not used here; the configurer falls back to the
// properties-backed default unless Feature.SettingService is set. Confirm this is intended.
func register(di initDI) {
	if di.SecRegistrar == nil || di.SessionStore == nil {
		return
	}
	configurer := newSessionConfigurer(di.SessionProps, di.SessionStore)
	di.SecRegistrar.(security.FeatureRegistrar).RegisterFeature(FeatureId, configurer)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package session
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/session/common"
"github.com/google/uuid"
"net/http"
"strings"
"time"
)
// Default flashes key.
const flashesKey = "_flash"

// createdTimeKey is the key in Session.values that holds the session's creation time (a time.Time).
const createdTimeKey = "_created"

// Session is a named, store-backed set of user values identified by an id.
// It also tracks the modification state and the timing information used for its expiration.
type Session struct {
	//Used to indicate if the session values have been modified - and should be saved
	dirty bool
	//Updated every time the session is accessed. Used to calculate the idle timeout
	lastAccessed time.Time
	// The id of the session, generated by stores. It should not be used for
	// user data.
	id string
	// values contains the user-data for the session.
	// Because the value is declared as interface, any concrete type that is stored in the values map need to register with gob
	// if used with a store that serializes using gob. See NewRedisStore
	// Should only be set through setter and not set directly
	values map[interface{}]interface{}
	// Should only be modified when session is created.
	options *Options
	// isNew is true for sessions created by NewSession/CreateSession, presumably until the store persists them.
	isNew bool
	// store is the Store this session is saved to, invalidated from and re-identified by
	store Store
	// name is the session name, also used as the cookie name
	name         string
	originalAuth security.Authentication
}

// Options holds the cookie attributes and timeouts of a session.
type Options struct {
	// Path and Domain are the session cookie's attributes
	Path   string
	Domain string
	// Determines how the cookie's "Max-Age" attribute will be set
	// MaxAge=0 means no 'Max-Age' attribute specified.
	// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
	// MaxAge>0 means Max-Age attribute present and given in seconds
	MaxAge int
	// Secure and HttpOnly are the session cookie's flags. Even when Secure is false, NewCookie may still set the
	// flag based on the request's X-Forwarded-Proto header
	Secure   bool
	HttpOnly bool
	// Defaults to http.SameSiteDefaultMode
	SameSite http.SameSite
	// IdleTimeout expires the session this long after its last access; a non-positive value disables it
	IdleTimeout time.Duration
	// AbsoluteTimeout expires the session this long after its creation; a non-positive value disables it
	AbsoluteTimeout time.Duration
}
// NewSession returns an empty new session of the given name, bound to store.
// The session carries its own copy of the store's options. It has no id, no creation time and no last-accessed
// time; use CreateSession for a fully initialized session.
func NewSession(store Store, name string) *Session {
	opts := *store.Options() // a copy of the store's option, so per-session changes don't leak into the store
	return &Session{
		values:  make(map[interface{}]interface{}),
		store:   store,
		name:    name,
		options: &opts,
		isNew:   true,
		dirty:   false,
	}
}

// CreateSession returns a new session of the given name with a random UUID as id.
// Its creation time and last-accessed time are set to the same current timestamp.
func CreateSession(store Store, name string) *Session {
	s := NewSession(store, name)
	// a single timestamp keeps creation and last access consistent (they used to differ by a few nanoseconds)
	now := time.Now()
	s.lastAccessed = now
	s.values[createdTimeKey] = now
	s.id = uuid.New().String()
	return s
}
// NewCookie returns an http.Cookie with the options set. It also sets the Expires field from the MaxAge value,
// for Internet Explorer compatibility.
//
// When options.Secure is false, the cookie is still marked Secure if the request arrived over HTTPS. This is
// detected either directly (r.TLS is set) or through a proxy reporting "https" in X-Forwarded-Proto
// (case-insensitive). r must not be nil.
func NewCookie(name, value string, options *Options, r *http.Request) *http.Cookie {
	cookie := newCookieFromOptions(name, value, options)
	switch {
	case options.MaxAge > 0:
		cookie.Expires = time.Now().Add(time.Duration(options.MaxAge) * time.Second)
	case options.MaxAge < 0:
		// Set it to the past to expire now.
		cookie.Expires = time.Unix(1, 0)
	}
	if !options.Secure {
		protoHeader := strings.ToLower(r.Header.Get("X-Forwarded-Proto"))
		cookie.Secure = r.TLS != nil || strings.Contains(protoHeader, "https")
	}
	return cookie
}

// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
	return &http.Cookie{
		Name:     name,
		Value:    value,
		Path:     options.Path,
		Domain:   options.Domain,
		MaxAge:   options.MaxAge,
		Secure:   options.Secure,
		HttpOnly: options.HttpOnly,
		SameSite: options.SameSite,
	}
}
// GetID returns the id of the session. The id is generated by the store (CreateSession uses a random UUID) and
// should not be used for user data.
func (s *Session) GetID() string {
	return s.id
}

// Name returns the name used to register the session, which is also the session cookie's name.
func (s *Session) Name() string {
	return s.name
}

// Get returns the session value associated to the given key, or nil when there's none.
func (s *Session) Get(key interface{}) interface{} {
	return s.values[key]
}

// Set sets the session value associated to the given key and marks the session dirty.
// Concrete value types may need to be registered with gob, depending on the store.
func (s *Session) Set(key interface{}, val interface{}) {
	s.values[key] = val
	s.SetDirty()
}

// Delete removes the session value associated to the given key.
// The session is only marked dirty when the key was present.
func (s *Session) Delete(key interface{}) {
	if _, ok := s.values[key]; ok {
		delete(s.values, key)
		s.SetDirty()
	}
}

// Clear deletes all values in the session, including internal ones such as the creation time, and marks the
// session dirty.
func (s *Session) Clear() {
	s.values = make(map[interface{}]interface{})
	s.SetDirty()
}
// flashKeyOrDefault returns the first of the optional flash keys, or flashesKey ("_flash") when none is given.
func flashKeyOrDefault(flashKey []string) string {
	if len(flashKey) == 0 {
		return flashesKey
	}
	return flashKey[0]
}

// Flashes returns and removes the flash messages stored under the given key.
//
// A single variadic argument is accepted, and it is optional: it defines
// the flash key. If not defined "_flash" is used by default.
// The session is marked dirty even when no flash message was found.
// It panics if the stored value is not a []interface{}.
func (s *Session) Flashes(flashKey ...string) []interface{} {
	defer s.SetDirty()
	key := flashKeyOrDefault(flashKey)
	stored, found := s.values[key]
	if !found {
		return nil
	}
	// Drop the flashes and return them.
	delete(s.values, key)
	return stored.([]interface{})
}

// Flash returns the last flash message of the given key, or nil when there's none.
// Like Flashes, it consumes every flash message of that key.
func (s *Session) Flash(key string) interface{} {
	messages := s.Flashes(key)
	if n := len(messages); n > 0 {
		return messages[n-1]
	}
	return nil
}

// AddFlash appends a flash message to the session and marks it dirty.
//
// A single variadic argument is accepted, and it is optional: it defines
// the flash key. If not defined "_flash" is used by default.
// It panics if the value already stored under the key is not a []interface{}.
func (s *Session) AddFlash(value interface{}, flashKey ...string) {
	key := flashKeyOrDefault(flashKey)
	var messages []interface{}
	if stored, found := s.values[key]; found {
		messages = stored.([]interface{})
	}
	s.values[key] = append(messages, value)
	s.SetDirty()
}
// ChangeId asks the store to assign a new id to this session.
func (s *Session) ChangeId() error {
	return s.store.ChangeId(s)
}

// Save persists this session through its store, but only when it is marked dirty (see IsDirty). Otherwise it
// returns nil without touching the store. It is a convenience for calling store.Save(s).
func (s *Session) Save() (err error) {
	if !s.dirty {
		return
	}
	err = s.store.Save(s)
	return
}

// IsDirty reports whether the session has been modified and should be saved.
func (s *Session) IsDirty() bool {
	return s.dirty
}

// SetDirty marks the session as modified.
func (s *Session) SetDirty() {
	s.dirty = true
}

// ExpireNow invalidates this session in its store, using the given context.
func (s *Session) ExpireNow(ctx context.Context) error {
	return s.store.WithContext(ctx).Invalidate(s)
}
// isExpired reports whether the session can expire and its expiration is already in the past.
func (s *Session) isExpired() bool {
	now := time.Now()
	canExpire, exp := s.expiration()
	return canExpire && exp.Before(now)
}

// createdOn returns the creation time stored under createdTimeKey, or the zero time when it is absent.
// It panics if the stored value is not a time.Time.
func (s *Session) createdOn() time.Time {
	created, found := s.values[createdTimeKey]
	if !found {
		return time.Time{}
	}
	return created.(time.Time)
}

// expiration computes the session's expiration from the enabled timeouts:
// the idle timeout counts from the last access, and the absolute timeout from the creation time.
// See common.CalculateExpiration for how they are combined.
func (s *Session) expiration() (canExpire bool, expiration time.Time) {
	var (
		setting         common.TimeoutSetting
		idleExp, absExp time.Time
	)
	if idle := s.options.IdleTimeout; idle > 0 {
		idleExp = s.lastAccessed.Add(idle)
		setting |= common.IdleTimeoutEnabled
	}
	if absolute := s.options.AbsoluteTimeout; absolute > 0 {
		absExp = s.createdOn().Add(absolute)
		setting |= common.AbsoluteTimeoutEnabled
	}
	return common.CalculateExpiration(setting, idleExp, absExp)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package session
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
)
// DefaultSettingService is a SettingService backed by static security.SessionProperties.
type DefaultSettingService struct {
	sessionProperty security.SessionProperties
}

// NewDefaultSettingService returns a SettingService that reads its settings from p.
func NewDefaultSettingService(p security.SessionProperties) SettingService {
	svc := &DefaultSettingService{}
	svc.sessionProperty = p
	return svc
}

// GetMaximumSessions returns security.SessionProperties.MaxConcurrentSession, regardless of the context.
// A non-positive value means no limit, as interpreted by ConcurrentSessionHandler.
func (d *DefaultSettingService) GetMaximumSessions(_ context.Context) int {
	return d.sessionProperty.MaxConcurrentSession
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package session
import (
"bytes"
"context"
"encoding/gob"
"fmt"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/session/common"
"github.com/google/uuid"
"github.com/pkg/errors"
"io"
"net/http"
"strconv"
"strings"
"time"
)
// Names of session attributes persisted by the store. Presumably these are the redis hash field names holding
// Session.values and Session.options — TODO confirm against the RedisStore implementation.
const (
	sessionValueField  = "values"
	sessionOptionField = "options"
)

// Global setting keys, presumably read through security.GlobalSettingReader to override the session
// timeouts (in seconds, judging by the names) — TODO confirm against the RedisStore implementation.
const (
	globalSettingIdleTimeout = "IDLE_SESSION_TIMEOUT_SECS"
	globalSettingAbsTimeout  = "ABSOLUTE_SESSION_TIMEOUT_SECS"
)

// Store persists sessions and maintains the index of sessions per principal.
type Store interface {
	// Get should return a cached session.
	// The session middleware calls it with an empty id when the request has no session cookie, and checks the
	// returned session's isNew flag, so implementations may return a new session when none is found.
	Get(id string, name string) (*Session, error)
	// New should create and return a new session.
	New(name string) (*Session, error)
	// Save should persist session to the underlying store implementation.
	Save(s *Session) error
	// Invalidate sessions from store.
	// It will also remove associations between sessions and its stored principal via RemoveFromPrincipalIndex
	Invalidate(sessions ...*Session) error
	// Options returns the store's default session options. NewSession copies them into every new session.
	Options() *Options
	// ChangeId assigns a new id to the given session, e.g. after authentication.
	ChangeId(s *Session) error
	// AddToPrincipalIndex associates the session with the given principal name.
	AddToPrincipalIndex(principal string, session *Session) error
	// RemoveFromPrincipalIndex removes the association between the session and the given principal name.
	RemoveFromPrincipalIndex(principal string, sessions *Session) error
	// FindByPrincipalName returns the sessions with the given name that are associated with the principal.
	// Callers rely on it to also clean expired sessions from the index.
	FindByPrincipalName(principal string, sessionName string) ([]*Session, error)
	// InvalidateByPrincipalName invalidate all sessions associated with given principal name
	InvalidateByPrincipalName(principal, sessionName string) error
	// WithContext make a shallow copy of the store with given ctx
	WithContext(ctx context.Context) Store
}
// RedisStore
/**
Session is implemented as a HSET in redis.
Session is expired using Redis TTL. The TTL is slightly longer than the expiration time so that when the session is
expired, we can still get the session details (if necessary).
Currently we don't need an "on session expired" event. If we need one in the future,
we can use https://redis.io/topics/notifications and listen for the expiration event. One caveat of using redis
keyspace notifications is that the expiration event may not be generated until the key is accessed. So if we want deterministic
behaviour on when the event is fired, we would need to implement our own scheduler that accesses these expired keys,
which forces redis to generate the event.
For each session:
1. HSET with key in the form of "LANAI:SESSION:SESSION:{sessionId}"
This stores session.values, session.options and session.lastAccessedTime as separate fields in the hash set
2. SET with key in the form of "LANAI:SESSION:INDEX:SESSION:{principal}"
This stores the set of session Id for this user. The session Id stored in this set may have been expired or deleted.
if we don't clean up the keys in this set, then that means on each successful login, we need to go through the
content of this set and find the corresponding session - sscan for the set entries, and then hgetall for each entry.
and filter the expired entries.
if we want to clean up the keys in this set, we need to do so with scheduled tasks. We cannot depend on the redis expiring
event, because when we get the event, the session is not available anymore. Therefore we need to introduce other data structures
to keep track of the expiration separately.
The worst-case scenario if we don't clean up this set
is when a user opens multiple sessions without logging out - these sessions remain in the set even
after they expire. This results in a penalty the next time the user logs in.
But if we don't expect a user to have millions of concurrent sessions, this should be insignificant.
If concurrent user limit is set, we don't expect the number of entries to be more than the concurrent user limit
which should be reasonably small.
If concurrent user limit is not set, it can grow large, but that is not a problem due to expiration - i.e. the
set can grow unbounded before any expiration event occurs. And the remedy to that is to apply a concurrent user limit.
The application can use redis SCAN family of commands to make sure that redis is not blocked by a single user's request.
*/
type RedisStore struct {
// ctx is passed to every redis call; WithContext returns a copy with a different ctx.
ctx context.Context
// options are the store's default session options (populated by NewRedisStore).
options *Options
// connection is the redis client used for all storage operations.
connection redis.Client
// settingReader optionally overrides the timeouts of new sessions; may be nil.
settingReader security.GlobalSettingReader
}
// StoreOptions is a functional option that customizes the StoreOption used by NewRedisStore.
type StoreOptions func(opt *StoreOption)
// StoreOption aggregates the configurable values of NewRedisStore.
type StoreOption struct {
// Options are the default session options; exposed afterwards via RedisStore.Options.
Options
// SettingReader, when non-nil, is consulted for the global idle/absolute timeout overrides.
SettingReader security.GlobalSettingReader
}
// NewRedisStore creates a RedisStore backed by redisClient.
//
// The defaults are Path "/", HttpOnly, SameSite=None, a 900s idle timeout and an 1800s
// absolute timeout. Each StoreOptions function is applied in order and may override them.
func NewRedisStore(redisClient redis.Client, options ...StoreOptions) *RedisStore {
	// gob needs concrete types registered when they are encoded through interface values
	gob.Register(time.Time{})

	cfg := StoreOption{}
	cfg.Path = "/"
	cfg.HttpOnly = true
	cfg.SameSite = http.SameSiteNoneMode
	cfg.IdleTimeout = 900 * time.Second
	cfg.AbsoluteTimeout = 1800 * time.Second
	for i := range options {
		options[i](&cfg)
	}

	store := new(RedisStore)
	store.ctx = context.Background()
	store.options = &cfg.Options
	store.connection = redisClient
	store.settingReader = cfg.SettingReader
	return store
}
// WithContext returns a shallow copy of the store whose redis calls use ctx.
// The options, connection and setting reader are shared with the original.
func (s *RedisStore) WithContext(ctx context.Context) Store {
	clone := new(RedisStore)
	*clone = *s
	clone.ctx = ctx
	return clone
}

// Options returns the store's default session options. This is the shared pointer, not a copy.
func (s *RedisStore) Options() *Options {
	opts := s.options
	return opts
}
// Get loads the session stored under (name, id).
// If id is empty, or no unexpired session exists for it, a brand-new session from New is
// returned instead. Redis errors are propagated.
func (s *RedisStore) Get(id string, name string) (*Session, error) {
	if id == "" {
		return s.New(name)
	}
	existing, err := s.load(id, name)
	switch {
	case err != nil:
		return nil, err
	case existing == nil:
		return s.New(name)
	default:
		return existing, nil
	}
}
// New creates a new session that has not been persisted yet.
// When a global setting reader is configured, the IDLE_SESSION_TIMEOUT_SECS and
// ABSOLUTE_SESSION_TIMEOUT_SECS settings override the session's idle and absolute
// timeouts. The returned error is always nil.
func (s *RedisStore) New(name string) (*Session, error) {
	created := CreateSession(s, name)
	idle, hasIdle := s.readTimeoutSetting(s.ctx, globalSettingIdleTimeout)
	if hasIdle {
		created.options.IdleTimeout = idle
	}
	abs, hasAbs := s.readTimeoutSetting(s.ctx, globalSettingAbsTimeout)
	if hasAbs {
		created.options.AbsoluteTimeout = abs
	}
	return created, nil
}
// Save stamps the session's last-accessed time with the current time and persists it.
// On success, the session is marked as neither dirty nor new.
// It returns an error if the session has no id or persisting fails.
func (s *RedisStore) Save(session *Session) error {
	if len(session.id) == 0 {
		return errors.New("session id is empty")
	}
	session.lastAccessed = time.Now()
	if err := s.save(session); err != nil {
		return err
	}
	session.dirty = false
	session.isNew = false
	return nil
}
// Invalidate deletes each session's redis hash, stopping at the first DEL error.
// If a session carries an authenticated principal, the session is also removed from that
// principal's index. This is best-effort: FindByPrincipalName prunes stale index entries on read.
func (s *RedisStore) Invalidate(sessions ...*Session) error {
	for _, sess := range sessions {
		key := common.GetRedisSessionKey(sess.Name(), sess.GetID())
		if err := s.connection.Del(s.ctx, key).Err(); err != nil {
			return err
		}
		principal, err := getPrincipalName(sess)
		if err != nil || principal == "" {
			continue
		}
		// the session itself is already gone, so a failure here is harmless
		_ = s.RemoveFromPrincipalIndex(principal, sess)
	}
	return nil
}
// InvalidateByPrincipalName invalidates every live session with the given session name
// that is indexed under principal.
func (s *RedisStore) InvalidateByPrincipalName(principal, sessionName string) error {
	found, err := s.FindByPrincipalName(principal, sessionName)
	if err == nil {
		err = s.Invalidate(found...)
	}
	return err
}
// FindByPrincipalName returns the live sessions with the given session name that are
// indexed under principal.
//
// It walks the principal's index set with SSCAN (default count), then loads each session.
// Ids whose session is missing or expired are omitted from the result and removed from the
// index on a best-effort basis.
func (s *RedisStore) FindByPrincipalName(principal string, sessionName string) ([]*Session, error) {
	indexKey := getRedisPrincipalIndexKey(principal, sessionName)

	// SSCAN until the cursor wraps back to 0
	var ids []string
	var cursor uint64
	for {
		members, next, err := s.connection.SScan(s.ctx, indexKey, cursor, "", 0).Result()
		if err != nil {
			return nil, err
		}
		ids = append(ids, members...)
		cursor = next
		if cursor == 0 {
			break
		}
	}

	var found []*Session
	var stale []interface{}
	for _, id := range ids {
		sess, err := s.load(id, sessionName)
		if err != nil {
			return nil, err
		}
		if sess != nil {
			found = append(found, sess)
			continue
		}
		stale = append(stale, id)
	}

	if len(stale) != 0 {
		// best-effort cleanup: a failure only leaves stale ids for the next lookup to prune
		_ = s.connection.SRem(s.ctx, indexKey, stale...)
	}
	return found, nil
}
// AddToPrincipalIndex adds the session's id to the principal's index set (SADD).
func (s *RedisStore) AddToPrincipalIndex(principal string, session *Session) error {
	key := getRedisPrincipalIndexKey(principal, session.Name())
	return s.connection.SAdd(s.ctx, key, session.GetID()).Err()
}

// RemoveFromPrincipalIndex removes the session's id from the principal's index set (SREM).
func (s *RedisStore) RemoveFromPrincipalIndex(principal string, session *Session) error {
	key := getRedisPrincipalIndexKey(principal, session.Name())
	return s.connection.SRem(s.ctx, key, session.GetID()).Err()
}
// ChangeId moves the session's redis hash to a key built from a freshly generated UUID
// (RENAME). The session's id is updated only if the rename succeeds.
// NOTE(review): the principal index set is not updated here; confirm that callers re-index.
func (s *RedisStore) ChangeId(session *Session) error {
	replacement := uuid.New().String()
	name := session.Name()
	from := common.GetRedisSessionKey(name, session.GetID())
	to := common.GetRedisSessionKey(name, replacement)
	if err := s.connection.Rename(s.ctx, from, to).Err(); err != nil {
		return err
	}
	session.id = replacement
	return nil
}
// load reads the session hash stored under (name, id).
// It returns (nil, nil) when the hash does not exist (empty HGETALL result) or when the
// decoded session has expired. Decoding and redis errors are returned as-is.
func (s *RedisStore) load(id string, name string) (*Session, error) {
	fields, err := s.connection.HGetAll(s.ctx, common.GetRedisSessionKey(name, id)).Result()
	if err != nil {
		return nil, err
	}
	if len(fields) == 0 {
		return nil, nil
	}

	loaded := NewSession(s, name)
	loaded.id = id
	for field, raw := range fields {
		var decodeErr error
		switch field {
		case sessionValueField:
			decodeErr = Deserialize(strings.NewReader(raw), &loaded.values)
		case sessionOptionField:
			decodeErr = Deserialize(strings.NewReader(raw), &loaded.options)
		case common.SessionLastAccessedField:
			// last-accessed time is stored as unix seconds
			var epoch int64
			epoch, decodeErr = strconv.ParseInt(raw, 10, 0)
			loaded.lastAccessed = time.Unix(epoch, 0)
		}
		if decodeErr != nil {
			return nil, decodeErr
		}
	}
	loaded.isNew = false

	if loaded.isExpired() {
		return nil, nil
	}
	return loaded, nil
}
// save writes the session hash with HSET and then sets its expiration (EXPIREAT) when the
// session can expire.
//
// The serialized values are written when the session is dirty or new. The serialized options,
// plus separate idle-timeout and absolute-expiry fields, are written only for new sessions.
// The last-accessed time (unix seconds) is always written.
func (s *RedisStore) save(session *Session) error {
	key := common.GetRedisSessionKey(session.Name(), session.GetID())
	var fields []interface{}

	if session.IsDirty() || session.isNew {
		encoded, err := Serialize(session.values)
		if err != nil {
			return err
		}
		fields = append(fields, sessionValueField, encoded)
	}

	if session.isNew {
		encoded, err := Serialize(session.options)
		if err != nil {
			return err
		}
		fields = append(fields, sessionOptionField, encoded)
		// timeouts are also stored as separate fields for easy retrieval
		if idle := session.options.IdleTimeout; idle > 0 {
			fields = append(fields, common.SessionIdleTimeoutDuration, idle.String())
		}
		if abs := session.options.AbsoluteTimeout; abs > 0 {
			fields = append(fields, common.SessionAbsTimeoutTime, session.createdOn().Add(abs).Unix())
		}
	}

	fields = append(fields, common.SessionLastAccessedField, session.lastAccessed.Unix())
	if err := s.connection.HSet(s.ctx, key, fields...).Err(); err != nil {
		return err
	}

	if canExpire, exp := session.expiration(); canExpire {
		return s.connection.ExpireAt(s.ctx, key, exp).Err()
	}
	return nil
}
// readTimeoutSetting reads the global setting key as an integer number of seconds.
// The boolean is false when no setting reader is configured or the read fails.
func (s *RedisStore) readTimeoutSetting(ctx context.Context, key string) (time.Duration, bool) {
	if s.settingReader != nil {
		var secs int
		if err := s.settingReader.Read(ctx, key, &secs); err == nil {
			return time.Duration(secs) * time.Second, true
		}
	}
	return 0, false
}
// getRedisPrincipalIndexKey returns the key of the redis SET that indexes session ids by principal,
// formatted as "<common.RedisNameSpace>:INDEX:<sessionName>:<principal>".
func getRedisPrincipalIndexKey(principal string, sessionName string) string {
return fmt.Sprintf("%s:INDEX:%s:%s", common.RedisNameSpace, sessionName, principal)
}
// Serialize gob-encodes src and returns the encoded bytes.
// Encoding errors are wrapped with the message "Cannot serialize value".
func Serialize(src interface{}) ([]byte, error) {
	buf := new(bytes.Buffer)
	if err := gob.NewEncoder(buf).Encode(src); err != nil {
		return nil, errors.Wrap(err, "Cannot serialize value")
	}
	return buf.Bytes(), nil
}
// Deserialize gob-decodes the data read from src into dst, which must be a pointer.
// Decoding errors are wrapped with the message "Cannot deserialize value". The original
// message wrongly said "serialize", which made decode failures look like encode failures.
func Deserialize(src io.Reader, dst interface{}) error {
	if err := gob.NewDecoder(src).Decode(dst); err != nil {
		return errors.Wrap(err, "Cannot deserialize value")
	}
	return nil
}
// getPrincipalName returns the username of the security.Authentication stored in the session
// under sessionKeySecurity. If no authentication is stored, it returns "" and no error.
func getPrincipalName(session *Session) (string, error) {
	if auth, ok := session.Get(sessionKeySecurity).(security.Authentication); ok {
		return security.GetUsername(auth)
	}
	return "", nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package security
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/mapping"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
)
/*****************************
webSecurity Impl
******************************/
// webSecurity is the default implementation of WebSecurity. It accumulates a route matcher, an
// optional request condition, handlers and features, and converts the handlers into web.Mapping(s) via Build.
type webSecurity struct {
// context is returned by Context(); context.TODO() is used when nil.
context context.Context
// routeMatcher is OR-combined by Route and defaulted to matcher.AnyRoute() by the build helpers when nil.
routeMatcher web.RouteMatcher
// conditionMatcher is OR-combined by Condition and AND-combined by AndCondition.
conditionMatcher web.RequestMatcher
// handlers holds values accepted by toAcceptedHandler.
handlers []interface{}
// features holds enabled features in the order they were enabled.
features []Feature
// shared is the key/value map backing Shared/AddShared; presumably shared with other instances by the creator — confirm.
shared map[string]interface{}
authenticator Authenticator
// applied holds identifiers of features already applied; Enable panics for these.
applied map[FeatureIdentifier]struct{}
// featuresChanged is set to true by Enable/Disable whenever the feature list changes.
featuresChanged bool
}
// newWebSecurity creates a webSecurity with no handlers, features or applied features,
// bound to ctx, authenticator and the caller-provided shared map.
func newWebSecurity(ctx context.Context, authenticator Authenticator, shared map[string]interface{}) *webSecurity {
	ws := &webSecurity{
		context:       ctx,
		authenticator: authenticator,
		shared:        shared,
	}
	ws.handlers = make([]interface{}, 0)
	ws.features = make([]Feature, 0)
	ws.applied = make(map[FeatureIdentifier]struct{})
	return ws
}
/* WebSecurity interface */

// Context returns the configuration context, or context.TODO() when none was provided.
func (ws *webSecurity) Context() context.Context {
	if ctx := ws.context; ctx != nil {
		return ctx
	}
	return context.TODO()
}

// Features returns the enabled features. This is the internal slice, not a copy.
func (ws *webSecurity) Features() []Feature {
	enabled := ws.features
	return enabled
}
// Route adds rm to the route matcher, OR-ing it with any previously configured matcher.
func (ws *webSecurity) Route(rm web.RouteMatcher) WebSecurity {
	if ws.routeMatcher == nil {
		ws.routeMatcher = rm
		return ws
	}
	ws.routeMatcher = ws.routeMatcher.Or(rm)
	return ws
}

// Condition adds mwcm to the request condition, OR-ing it with any previously configured condition.
func (ws *webSecurity) Condition(mwcm web.RequestMatcher) WebSecurity {
	if ws.conditionMatcher == nil {
		ws.conditionMatcher = mwcm
		return ws
	}
	ws.conditionMatcher = ws.conditionMatcher.Or(mwcm)
	return ws
}

// AndCondition adds mwcm to the request condition, AND-ing it with any previously configured condition.
func (ws *webSecurity) AndCondition(mwcm web.RequestMatcher) WebSecurity {
	if ws.conditionMatcher == nil {
		ws.conditionMatcher = mwcm
		return ws
	}
	ws.conditionMatcher = ws.conditionMatcher.And(mwcm)
	return ws
}
// With enables f and panics if a feature with the same identifier was already enabled
// (i.e. Enable returned a different instance).
func (ws *webSecurity) With(f Feature) WebSecurity {
	if enabled := ws.Enable(f); enabled != f {
		panic(fmt.Errorf("cannot re-enable feature [%v] using With()", f.Identifier()))
	}
	return ws
}
// Add validates the handlers and appends them (see toAcceptedHandler for the accepted types).
// *middleware.MappingBuilder and *mapping.MappingBuilder values are converted to their template types.
// It panics on the first unsupported handler, in which case nothing is added.
//
// The converted values go into a fresh slice. Writing them back into `handlers` would modify
// the caller's own slice when Add is invoked as Add(s...).
func (ws *webSecurity) Add(handlers ...interface{}) WebSecurity {
	accepted := make([]interface{}, len(handlers))
	for i, h := range handlers {
		v, err := ws.toAcceptedHandler(h)
		if err != nil {
			panic(err)
		}
		accepted[i] = v
	}
	ws.handlers = append(ws.handlers, accepted...)
	return ws
}
// Remove deletes the first occurrence of each given handler after normalizing it the same
// way Add does. It panics if a handler is of an unsupported type.
func (ws *webSecurity) Remove(handlers ...interface{}) WebSecurity {
	for i := range handlers {
		target, err := ws.toAcceptedHandler(handlers[i])
		if err != nil {
			panic(err)
		}
		ws.handlers = remove(ws.handlers, target)
	}
	return ws
}
// Shared returns the value stored under key, or nil if absent.
func (ws *webSecurity) Shared(key string) interface{} {
	return ws.shared[key]
}

// AddShared stores value under key. It returns an error if key already exists, so existing
// values are never overwritten. A nil shared map is initialized lazily; without this,
// the write would panic.
func (ws *webSecurity) AddShared(key string, value interface{}) error {
	if ws.shared == nil {
		ws.shared = map[string]interface{}{}
	}
	if _, exists := ws.shared[key]; exists {
		return fmt.Errorf("cannot add shared value to WebSecurity %v: key [%s] already exists", ws, key)
	}
	ws.shared[key] = value
	return nil
}

// Authenticator returns the authenticator this WebSecurity was created with.
func (ws *webSecurity) Authenticator() Authenticator {
	return ws.authenticator
}
/* FeatureModifier interface */

// Enable adds f to the enabled features. If a feature with the same identifier is already
// enabled, that existing instance is returned instead and nothing changes; otherwise f is
// returned. It panics if a feature with f's identifier has already been applied.
func (ws *webSecurity) Enable(f Feature) Feature {
	if _, done := ws.applied[f.Identifier()]; done {
		panic(fmt.Errorf("attempt to configure security feature [%v] after it has been applied", f.Identifier()))
	}
	idx := findFeatureIndex(ws.features, f)
	if idx >= 0 {
		return ws.features[idx]
	}
	ws.featuresChanged = true
	ws.features = append(ws.features, f)
	return f
}

// Disable removes the enabled feature that has f's identifier, if any, and keeps the
// order of the remaining features.
func (ws *webSecurity) Disable(f Feature) {
	idx := findFeatureIndex(ws.features, f)
	if idx < 0 {
		return
	}
	ws.featuresChanged = true
	last := len(ws.features) - 1
	copy(ws.features[idx:], ws.features[idx+1:])
	ws.features[last] = nil // drop the duplicated tail reference
	ws.features = ws.features[:last]
}
/* WebSecurityReader interface */
// GetRoute returns the configured route matcher. It may be nil until Route is called or a build helper defaults it.
func (ws *webSecurity) GetRoute() web.RouteMatcher {
return ws.routeMatcher
}
// GetCondition returns the configured request condition, or nil if none was set.
func (ws *webSecurity) GetCondition() web.RequestMatcher {
return ws.conditionMatcher
}
// GetHandlers returns the accepted handlers. This is the internal slice, not a copy.
func (ws *webSecurity) GetHandlers() []interface{} {
return ws.handlers
}
/* WebSecurityMappingBuilder interface */

// Build converts every handler into a web.Mapping, in handler order.
// Templates get this WebSecurity's route matcher and condition applied; ready-made
// mappings are used as-is. It panics on an unsupported handler type.
func (ws *webSecurity) Build() []web.Mapping {
	result := make([]web.Mapping, len(ws.handlers))
	for idx := range ws.handlers {
		switch h := ws.handlers[idx].(type) {
		case MiddlewareTemplate:
			result[idx] = ws.buildFromMiddlewareTemplate(h)
		case SimpleMappingTemplate:
			result[idx] = ws.buildFromSimpleMappingTemplate(h)
		// web.Mapping is intentionally not matched: the interface is too simple and may give false positives
		case web.RoutedMapping:
			result[idx] = h
		case web.StaticMapping:
			result[idx] = h
		case web.MiddlewareMapping:
			result[idx] = h
		default:
			panic(fmt.Errorf("unable to build security mappings from unsupported WebSecurity handler [%T]", h))
		}
	}
	return result
}
// Other interfaces

// String describes the route matcher, the condition and the identifiers of the enabled features.
func (ws *webSecurity) String() string {
	ids := make([]FeatureIdentifier, 0, len(ws.features))
	for _, feature := range ws.features {
		ids = append(ids, feature.Identifier())
	}
	return fmt.Sprintf("matcher=%v, condition=%v, features=%v", ws.routeMatcher, ws.conditionMatcher, ids)
}

// GoString returns the same text as String, so %#v prints the readable form.
func (ws *webSecurity) GoString() string {
	text := ws.String()
	return text
}
// unexported

// buildFromMiddlewareTemplate builds a middleware mapping from tmpl.
// Side effect: when no route matcher is configured, ws.routeMatcher is set to matcher.AnyRoute().
// The template's own route matcher and condition take precedence. This WebSecurity's route matcher
// and condition are applied only when the template has none.
func (ws *webSecurity) buildFromMiddlewareTemplate(tmpl MiddlewareTemplate) web.Mapping {
builder := (*middleware.MappingBuilder)(tmpl)
if ws.routeMatcher == nil {
ws.routeMatcher = matcher.AnyRoute()
}
if builder.GetRouteMatcher() == nil {
builder = builder.ApplyTo(ws.routeMatcher)
}
if ws.conditionMatcher != nil && builder.GetCondition() == nil {
builder = builder.WithCondition(ws.conditionMatcher)
}
return builder.Build()
}
// buildFromSimpleMappingTemplate builds a simple mapping from tmpl. This WebSecurity's condition is
// applied only when the template has none.
// NOTE(review): ws.routeMatcher is defaulted to matcher.AnyRoute() here but never applied to the builder;
// presumably simple mappings carry their own path — confirm this side effect is intended.
func (ws *webSecurity) buildFromSimpleMappingTemplate(tmpl SimpleMappingTemplate) web.Mapping {
builder := (*mapping.MappingBuilder)(tmpl)
if ws.routeMatcher == nil {
ws.routeMatcher = matcher.AnyRoute()
}
if ws.conditionMatcher != nil && builder.GetCondition() == nil {
builder = builder.Condition(ws.conditionMatcher)
}
return builder.Build()
}
// toAcceptedHandler validates v as a WebSecurity handler and normalizes it.
// *middleware.MappingBuilder becomes MiddlewareTemplate and *mapping.MappingBuilder becomes
// SimpleMappingTemplate. Templates, web.RoutedMapping, web.StaticMapping and
// web.MiddlewareMapping are returned unchanged. Anything else yields an error.
func (ws *webSecurity) toAcceptedHandler(v interface{}) (interface{}, error) {
	// concrete builder pointers are converted to their template counterparts
	if builder, ok := v.(*middleware.MappingBuilder); ok {
		return MiddlewareTemplate(builder), nil
	}
	if builder, ok := v.(*mapping.MappingBuilder); ok {
		return SimpleMappingTemplate(builder), nil
	}
	switch v.(type) {
	// web.Mapping is intentionally not accepted: the interface is too simple and may give false positives
	case MiddlewareTemplate, SimpleMappingTemplate, web.RoutedMapping, web.StaticMapping, web.MiddlewareMapping:
		return v, nil
	default:
		return nil, fmt.Errorf("unsupported WebSecurity handler [%T]", v)
	}
}
// remove deletes the first element equal (==) to item and shifts later elements left.
// The backing array is modified in place and the vacated last slot is set to nil.
// If item is not present, slice is returned unchanged.
func remove(slice []interface{}, item interface{}) []interface{} {
	idx := -1
	for i := range slice {
		if slice[i] == item {
			idx = i
			break
		}
	}
	if idx < 0 {
		return slice
	}
	n := len(slice)
	copy(slice[idx:], slice[idx+1:])
	slice[n-1] = nil
	return slice[:n-1]
}
// findFeatureIndex returns the index of the first feature in slice whose identifier equals
// f's identifier, or -1 when there is none.
func findFeatureIndex(slice []Feature, f Feature) int {
	for i := 0; i < len(slice); i++ {
		if f.Identifier() == slice[i].Identifier() {
			return i
		}
	}
	return -1
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package swagger
import (
"encoding/json"
"errors"
)
// Supported OpenAPI Specification versions. OAS 3 documents with any 3.0.x / 3.1.x patch
// version are accepted and normalized to OASv30 / OASv31 (see OpenApiSpec.UnmarshalJSON).
const (
	OASv2  = "2.0"
	OASv30 = "3.0.0"
	OASv31 = "3.1.0"
)

// OASVersion is the version string of an OpenAPI (or legacy Swagger) document.
type OASVersion string

// oasDoc decodes only the version markers of a document, to detect which spec it follows.
type oasDoc struct {
	OASLegacyVer OASVersion `json:"swagger"`
	OAS3Ver      OASVersion `json:"openapi"`
}

// OpenApiSpec holds a parsed OAS document. After a successful UnmarshalJSON, exactly one of
// OAS2 / OAS3 is set, and Version is OASv2, OASv30 or OASv31 accordingly.
type OpenApiSpec struct {
	Version OASVersion `json:"Version"`
	OAS2    *OAS2      `json:"OAS2"`
	OAS3    *OAS3      `json:"OAS3"`
}

// UnmarshalJSON detects the document's spec version and decodes data into OAS2 or OAS3.
//
// OAS 3 patch releases (e.g. "3.0.3", "3.1.1") are valid spec versions. Previously they were
// rejected as unknown. They are now accepted, and Version is set to the canonical OASv30 / OASv31
// constant that downstream code switches on. OAS3.OpenAPIVersion keeps the original string.
// It returns an error for any other version.
func (s *OpenApiSpec) UnmarshalJSON(data []byte) error {
	var doc oasDoc
	if e := json.Unmarshal(data, &doc); e != nil {
		return e
	}
	var specPtr interface{}
	switch v3 := canonicalOAS3Version(doc.OAS3Ver); {
	case len(v3) != 0:
		s.Version = v3
		s.OAS3 = &OAS3{}
		specPtr = s.OAS3
	case len(doc.OAS3Ver) == 0 && doc.OASLegacyVer == OASv2:
		s.Version = OASv2
		s.OAS2 = &OAS2{}
		specPtr = s.OAS2
	default:
		return errors.New("unknown OAS document version")
	}
	return json.Unmarshal(data, specPtr)
}

// canonicalOAS3Version maps "3.0.<patch>" to OASv30 and "3.1.<patch>" to OASv31.
// It returns "" for any other value.
func canonicalOAS3Version(v OASVersion) OASVersion {
	for _, canonical := range []OASVersion{OASv30, OASv31} {
		prefix := canonical[:4] // "major.minor.", e.g. "3.0."
		if len(v) > len(prefix) && v[:len(prefix)] == prefix {
			return canonical
		}
	}
	return ""
}

// OAS2 is the Swagger 2.0 Specification
// https://swagger.io/docs/specification/2-0/basic-structure/
// https://github.com/OAI/OpenAPI-Specification/blob/main/versions/2.0.md
type OAS2 struct {
	OpenAPIVersion string                 `json:"swagger"`
	Info           OAS2Info               `json:"info"`
	Host           string                 `json:"host,omitempty"`
	BasePath       string                 `json:"basePath,omitempty"`
	Schemes        []string               `json:"schemes,omitempty"`
	Consumes       []string               `json:"consumes,omitempty"`
	Produces       []string               `json:"produces,omitempty"`
	Paths          map[string]interface{} `json:"paths,omitempty"`
	Definitions    map[string]interface{} `json:"definitions,omitempty"`
	Parameters     map[string]interface{} `json:"parameters,omitempty"`
	Responses      map[string]interface{} `json:"responses,omitempty"`
	SecDefs        map[string]interface{} `json:"securityDefinitions,omitempty"`
	Security       []interface{}          `json:"security,omitempty"`
	Tags           []interface{}          `json:"tags,omitempty"`
	ExtDocs        map[string]interface{} `json:"externalDocs,omitempty"`
}

// OAS2Info is the "info" object of a Swagger 2.0 document.
type OAS2Info struct {
	Title          string                 `json:"title"`
	Description    string                 `json:"description,omitempty"`
	TermsOfService string                 `json:"termsOfService,omitempty"`
	Contact        map[string]interface{} `json:"contact,omitempty"`
	License        map[string]interface{} `json:"license,omitempty"`
	Version        string                 `json:"version,omitempty"`
}

// OAS3 is the OpenAPI 3.x (3.0 / 3.1) Specification
// https://swagger.io/docs/specification/basic-structure/
// https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#schema
type OAS3 struct {
	OpenAPIVersion string                 `json:"openapi"`
	Info           OAS3Info               `json:"info"`
	JsonDialect    string                 `json:"jsonSchemaDialect,omitempty"`
	Servers        []OAS3Server           `json:"servers,omitempty"`
	Paths          map[string]interface{} `json:"paths,omitempty"`
	WebHooks       map[string]interface{} `json:"webhooks,omitempty"`
	Components     map[string]interface{} `json:"components,omitempty"`
	Security       []interface{}          `json:"security,omitempty"`
	Tags           []interface{}          `json:"tags,omitempty"`
	ExtDocs        map[string]interface{} `json:"externalDocs,omitempty"`
}

// OAS3Info is the "info" object of an OpenAPI 3.x document.
type OAS3Info struct {
	Title          string                 `json:"title"`
	Summary        string                 `json:"summary,omitempty"`
	Description    string                 `json:"description,omitempty"`
	TermsOfService string                 `json:"termsOfService,omitempty"`
	Contact        map[string]interface{} `json:"contact,omitempty"`
	License        map[string]interface{} `json:"license,omitempty"`
	Version        string                 `json:"version,omitempty"`
}

// OAS3Server is an entry of the "servers" array of an OpenAPI 3.x document.
type OAS3Server struct {
	URL         string                 `json:"url"`
	Description string                 `json:"description,omitempty"`
	Variables   map[string]interface{} `json:"variables,omitempty"`
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package swagger
import (
"context"
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/assets"
"github.com/cisco-open/go-lanai/pkg/web/rest"
"io/fs"
"net/http"
"strings"
)
// UiConfiguration is the Swagger UI configuration returned by /swagger-resources/configuration/ui.
type UiConfiguration struct {
ApisSorter string `json:"apisSorter"`
DeepLinking bool `json:"deepLinking"`
DefaultModelExpandDepth int `json:"defaultModelExpandDepth"`
DefaultModelRendering string `json:"defaultModelRendering"`
DefaultModelsExpandDepth int `json:"defaultModelsExpandDepth"`
DisplayOperationId bool `json:"displayOperationId"`
DisplayRequestDuration bool `json:"displayRequestDuration"`
DocExpansion string `json:"docExpansion"`
Filter bool `json:"filter"`
JsonEditor bool `json:"jsonEditor"`
OperationsSorter string `json:"operationsSorter"`
ShowExtensions bool `json:"showExtensions"`
ShowRequestHeaders bool `json:"showRequestHeaders"`
SupportedSubmitMethods []string `json:"supportedSubmitMethods"`
TagsSorter string `json:"tagsSorter"`
ValidatorUrl string `json:"validatorUrl"`
Title string `json:"title"`
}
// SsoConfiguration is the SSO configuration returned by /swagger-resources/configuration/security/sso.
// Enabled is true when a client ID is configured.
type SsoConfiguration struct {
Enabled bool `json:"enabled"`
AuthorizeUrl string `json:"authorizeUrl"`
ClientId string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
TokenUrl string `json:"tokenUrl"`
AdditionalParams []ParamMeta `json:"additionalParameters"`
}
// ParamMeta describes one additional SSO parameter (see SsoConfiguration.AdditionalParams).
type ParamMeta struct {
Name string `json:"name"`
DisplayName string `json:"displayName"`
CandidateSourceUrl string `json:"candidateSourceUrl"`
CandidateJsonPath string `json:"candidateJsonPath"`
}
// Resource describes an API documentation group listed by /swagger-resources.
type Resource struct {
Name string `json:"name"`
Location string `json:"location"`
Url string `json:"url"`
SwaggerVersion string `json:"swaggerVersion"`
}
// SwaggerController serves the Swagger UI, its configuration endpoints and the OAS documents.
//
//goland:noinspection GoNameStartsWithPackageName
type SwaggerController struct {
	properties        *SwaggerProperties
	buildInfoResolver bootstrap.BuildInfoResolver
	docLoader         *OASDocLoader
}

// NewSwaggerController creates a controller that loads the OAS document at props.Spec
// using the default search file system.
func NewSwaggerController(props SwaggerProperties, resolver bootstrap.BuildInfoResolver) *SwaggerController {
	ctrl := newSwaggerController(props, resolver)
	return ctrl
}

// newSwaggerController creates a controller whose OAS document is searched for in searchFS,
// in order. When searchFS is empty, the loader falls back to its default file system.
func newSwaggerController(props SwaggerProperties, resolver bootstrap.BuildInfoResolver, searchFS ...fs.FS) *SwaggerController {
	ctrl := new(SwaggerController)
	ctrl.properties = &props
	ctrl.buildInfoResolver = resolver
	ctrl.docLoader = newOASDocLoader(props.Spec, searchFS...)
	return ctrl
}
// Mappings returns the web mappings served by the controller: static UI assets, the UI page,
// the UI/security/SSO configuration endpoints, the resource list, the SSO redirect page,
// and the OAS documents at /v2/api-docs and /v3/api-docs.
func (c *SwaggerController) Mappings() []web.Mapping {
return []web.Mapping{
assets.New("/swagger/static", "generated/"),
web.NewSimpleMapping("swagger-ui", "", "/swagger", http.MethodGet, nil, c.swagger),
rest.New("swagger-configuration-ui").Get("/swagger-resources/configuration/ui").EndpointFunc(c.configurationUi).Build(),
rest.New("swagger-configuration-security").Get("/swagger-resources/configuration/security").EndpointFunc(c.configurationSecurity).Build(),
rest.New("swagger-configuration-sso").Get("/swagger-resources/configuration/security/sso").EndpointFunc(c.configurationSso).Build(),
rest.New("swagger-resources").Get("/swagger-resources").EndpointFunc(c.resources).Build(),
// NOTE(review): unlike the other paths, this one has no leading "/"; confirm the web package normalizes it.
web.NewSimpleMapping("swagger-sso-redirect", "", "swagger-sso-redirect.html", http.MethodGet, nil, c.swaggerRedirect),
web.NewSimpleMapping("swagger-spec", "", "/v2/api-docs", http.MethodGet, nil, c.oas2Doc),
web.NewSimpleMapping("oas3-spec", "", "/v3/api-docs", http.MethodGet, nil, c.oas3Doc),
}
}
// configurationUi returns the fixed Swagger UI configuration. Only Title comes from properties.
// Fields not listed here keep their zero values.
func (c *SwaggerController) configurationUi(_ context.Context, _ web.EmptyRequest) (response interface{}, err error) {
	cfg := UiConfiguration{
		Title:                    c.properties.UI.Title,
		DeepLinking:              true,
		DefaultModelExpandDepth:  1,
		DefaultModelsExpandDepth: 1,
		DefaultModelRendering:    "example",
		DocExpansion:             "none",
		OperationsSorter:         "alpha",
		TagsSorter:               "alpha",
		SupportedSubmitMethods:   []string{"get", "put", "post", "delete", "options", "head", "patch", "trace"},
	}
	return cfg, nil
}
// swagger serves the embedded Swagger UI page (generated/swagger-ui.html from Content).
// It responds with 500 when the page cannot be opened or stat'ed.
//
// The opened http.File is an io.Closer and is now closed after serving; previously it was
// never closed. The local variable no longer shadows the imported io/fs package.
func (c *SwaggerController) swagger(w http.ResponseWriter, r *http.Request) {
	file, err := http.FS(Content).Open("generated/swagger-ui.html")
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	defer func() { _ = file.Close() }()

	fileInfo, err := file.Stat()
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	http.ServeContent(w, r, fileInfo.Name(), fileInfo.ModTime(), file)
}
// configurationSecurity returns an empty security configuration object, which encodes as {}.
func (c *SwaggerController) configurationSecurity(_ context.Context, _ web.EmptyRequest) (response interface{}, err error) {
	return struct{}{}, nil
}
// configurationSso returns the SSO settings for Swagger UI. SSO is reported as enabled when a
// client ID is configured. The token and authorize URLs are BaseUrl followed by the respective path.
// AdditionalParams stays nil when no additional parameters are configured.
func (c *SwaggerController) configurationSso(_ context.Context, _ web.EmptyRequest) (response interface{}, err error) {
	sso := c.properties.Security.Sso
	var params []ParamMeta
	for _, p := range sso.AdditionalParams {
		meta := ParamMeta{
			Name:               p.Name,
			DisplayName:        p.DisplayName,
			CandidateSourceUrl: p.CandidateSourceUrl,
			CandidateJsonPath:  p.CandidateJsonPath,
		}
		params = append(params, meta)
	}
	return SsoConfiguration{
		Enabled:          sso.ClientId != "",
		AuthorizeUrl:     fmt.Sprintf("%s%s", sso.BaseUrl, sso.AuthorizePath),
		TokenUrl:         fmt.Sprintf("%s%s", sso.BaseUrl, sso.TokenPath),
		ClientId:         sso.ClientId,
		ClientSecret:     sso.ClientSecret,
		AdditionalParams: params,
	}, nil
}
// resources lists the available API documentation groups. Currently this is only the
// "platform" group, served by the /v3 endpoint.
func (c *SwaggerController) resources(_ context.Context, _ web.EmptyRequest) (response interface{}, err error) {
	const platformDocUrl = "/v3/api-docs?group=platform"
	platform := Resource{
		Name:           "platform",
		Location:       platformDocUrl,
		Url:            platformDocUrl,
		SwaggerVersion: "3.0",
	}
	return []Resource{platform}, nil
}
// oas2Doc serves the processed OAS document as JSON, but only if it is a Swagger 2.0 document.
// Any failure, including the document being OAS3, is logged and answered with 500.
func (c *SwaggerController) oas2Doc(w http.ResponseWriter, r *http.Request) {
	var err error
	//nolint:contextcheck // we want r.Context() at the end of execution
	defer func() {
		if err != nil {
			logger.WithContext(r.Context()).Errorf("Failed to serve OAS document: %v", err)
			w.WriteHeader(http.StatusInternalServerError)
		}
	}()

	var doc *OpenApiSpec
	if doc, err = c.docLoader.Load(); err != nil {
		return
	}
	var oas interface{}
	if oas, err = c.process(doc, r); err != nil {
		return
	}
	if doc.Version != OASv2 {
		err = fmt.Errorf("OAS3 document is not supported by /v2 endpoint")
		return
	}
	w.Header().Set(web.HeaderContentType, "application/json")
	err = json.NewEncoder(w).Encode(oas)
}
// oas3Doc serves the processed OAS document as JSON. Any loading, processing or encoding
// failure is logged and answered with 500.
// NOTE(review): unlike oas2Doc, this endpoint does not reject documents of the other version.
func (c *SwaggerController) oas3Doc(w http.ResponseWriter, r *http.Request) {
	var err error
	//nolint:contextcheck // we want r.Context() at the end of execution
	defer func() {
		if err != nil {
			logger.WithContext(r.Context()).Errorf("Failed to serve OAS document: %v", err)
			w.WriteHeader(http.StatusInternalServerError)
		}
	}()

	var doc *OpenApiSpec
	if doc, err = c.docLoader.Load(); err != nil {
		return
	}
	var oas interface{}
	if oas, err = c.process(doc, r); err != nil {
		return
	}
	w.Header().Set(web.HeaderContentType, "application/json")
	err = json.NewEncoder(w).Encode(oas)
}
// swaggerRedirect serves the embedded SSO redirect page (generated/swagger-sso-redirect.html).
// Open/stat failures are logged and answered with 500.
//
// The opened http.File is an io.Closer and is now closed after serving; previously it was
// never closed. The local variable no longer shadows the imported io/fs package.
func (c *SwaggerController) swaggerRedirect(w http.ResponseWriter, r *http.Request) {
	const name = "generated/swagger-sso-redirect.html"
	file, err := http.FS(Content).Open(name)
	if err != nil {
		logger.WithContext(r.Context()).Errorf("Unable to open file '%s': %v", name, err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	defer func() { _ = file.Close() }()

	fileInfo, err := file.Stat()
	if err != nil {
		logger.WithContext(r.Context()).Errorf("Unable to stat file '%s': %v", name, err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	http.ServeContent(w, r, fileInfo.Name(), fileInfo.ModTime(), file)
}
// msxVersion returns the application version shown in the OAS document: the resolver's
// version when a resolver is configured, otherwise bootstrap.BuildVersion. An "unknown"
// build version (case-insensitive) is reported as "".
func (c *SwaggerController) msxVersion() string {
	switch {
	case c.buildInfoResolver != nil:
		return c.buildInfoResolver.Resolve().Version
	case strings.ToLower(bootstrap.BuildVersion) == "unknown":
		return ""
	default:
		return bootstrap.BuildVersion
	}
}
// process adjusts the document for the current request (host/servers and version) and
// returns the version-specific document to encode. Unknown versions yield an error.
func (c *SwaggerController) process(doc *OpenApiSpec, r *http.Request) (interface{}, error) {
	if doc.Version == OASv2 {
		return doc.OAS2, c.processOAS2(doc.OAS2, r)
	}
	if doc.Version == OASv30 || doc.Version == OASv31 {
		return doc.OAS3, c.processOAS3(doc.OAS3, r)
	}
	return nil, fmt.Errorf("unknown OAS document version")
}
// processOAS2 sets the document's host from the first X-Forwarded-Host entry (or r.Host when
// the header is absent) and sets info.version to the application version. It never fails.
func (c *SwaggerController) processOAS2(oas *OAS2, r *http.Request) error {
	host := r.Host
	// header lookup is case-insensitive (canonicalized by net/http)
	if fwd := r.Header.Get("X-Forwarded-Host"); fwd != "" {
		// proxies append hosts as a comma-separated list; the first one is the client-facing host
		host = strings.TrimSpace(strings.SplitN(fwd, ",", 2)[0])
	}
	oas.Host = host
	oas.Info.Version = c.msxVersion()
	return nil
}
// processOAS3 sets info.version to the application version and replaces the document's
// servers. When X-Forwarded-Host is present, a single "Current API Server" entry is built from
// the first forwarded host, the first X-Forwarded-Proto entry (default "http") and the
// configured BasePath. Otherwise the server list is left empty. It never fails.
//
// Fixes: the leading "/" check now uses the trimmed BasePath, so " /api" no longer yields
// "//api" and a whitespace-only BasePath no longer appends a stray "/". An empty first
// X-Forwarded-Proto entry now falls back to "http" instead of producing "://host".
func (c *SwaggerController) processOAS3(oas *OAS3, r *http.Request) error {
	oas.Info.Version = c.msxVersion()

	// servers declared in the document are discarded
	oas.Servers = nil
	fwdAddr := strings.TrimSpace(r.Header.Get("X-Forwarded-Host")) // header lookup is case-insensitive
	if len(fwdAddr) == 0 {
		return nil
	}
	host := strings.TrimSpace(strings.Split(fwdAddr, ",")[0])

	scheme := "http"
	if fwdProto := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); len(fwdProto) != 0 {
		if first := strings.TrimSpace(strings.Split(fwdProto, ",")[0]); first != "" {
			scheme = first
		}
	}

	serverUrl := fmt.Sprintf("%s://%s", scheme, host)
	if basePath := strings.TrimSpace(c.properties.BasePath); basePath != "" {
		if !strings.HasPrefix(basePath, "/") {
			basePath = "/" + basePath
		}
		serverUrl += basePath
	}

	oas.Servers = []OAS3Server{{
		URL:         serverUrl,
		Description: "Current API Server",
	}}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package swagger
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/discovery"
)
// TagSwaggerPath is the key of the service-registration tag that advertises the Swagger UI path.
const (
	TagSwaggerPath = "swaggerPath"
)

// newSwaggerInfoDiscoveryCustomizer returns a customizer that tags service registrations
// with the Swagger UI path.
func newSwaggerInfoDiscoveryCustomizer() discovery.ServiceRegistrationCustomizer {
	var customizer swaggerInfoDiscoveryCustomizer
	return customizer
}

// swaggerInfoDiscoveryCustomizer adds the "swaggerPath=/swagger" tag to service registrations.
type swaggerInfoDiscoveryCustomizer struct{}

// Customize adds the "swaggerPath=/swagger" tag to reg.
func (s swaggerInfoDiscoveryCustomizer) Customize(_ context.Context, reg discovery.ServiceRegistration) {
	tag := fmt.Sprintf("%s=%s", TagSwaggerPath, "/swagger")
	reg.AddTags(tag)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package swagger
import (
"encoding/json"
"errors"
"fmt"
"github.com/ghodss/yaml"
"io"
"io/fs"
"os"
"path"
"strings"
)
type OASDocLoader struct {
path string
searchFS []fs.FS
}
func newOASDocLoader(path string, searchFS ...fs.FS) *OASDocLoader {
if len(searchFS) == 0 {
searchFS = []fs.FS{os.DirFS(".")}
}
return &OASDocLoader{
path: path,
searchFS: searchFS,
}
}
func (l OASDocLoader) Load() (*OpenApiSpec, error) {
// find docs file
var file fs.File
var e error
for _, fsys := range l.searchFS {
switch file, e = fsys.Open(l.path); {
case errors.Is(e, fs.ErrNotExist):
continue
case e != nil:
return nil, e
}
break
}
if file == nil {
return nil, fs.ErrNotExist
}
defer func() { _ = file.Close() }()
// load docs
var oas OpenApiSpec
switch fileExt := strings.ToLower(path.Ext(l.path)); fileExt {
case ".yml", ".yaml":
data, e := io.ReadAll(file)
if e != nil {
return nil, e
}
if e := yaml.Unmarshal(data, &oas); e != nil {
return nil, e
}
case ".json", ".json5":
decoder := json.NewDecoder(file)
if e := decoder.Decode(&oas); e != nil {
return nil, e
}
default:
return nil, fmt.Errorf("unsupported file extension for OAS document: %s", fileExt)
}
return &oas, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package swagger
import (
"embed"
"fmt"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/pkg/errors"
"go.uber.org/fx"
)
//go:generate npm install --prefix nodejs
//go:generate npm run --prefix nodejs build --output_dir=../generated
//go:embed generated/*
var Content embed.FS
//go:embed defaults-swagger.yml
var defaultConfigFS embed.FS
var logger = log.New("Swagger")
var Module = &bootstrap.Module{
Name: "swagger",
Precedence: bootstrap.SwaggerPrecedence,
PriorityOptions: []fx.Option{
fx.Invoke(configureSecurity),
},
Options: []fx.Option{
appconfig.FxEmbeddedDefaults(defaultConfigFS),
fx.Provide(
bindSwaggerProperties,
fx.Annotate(newSwaggerInfoDiscoveryCustomizer, fx.ResultTags(fmt.Sprintf(`group:"%s"`, discovery.FxGroup))),
),
fx.Invoke(initialize),
},
}
func Use() {
bootstrap.Register(Module)
}
type initDI struct {
fx.In
Registrar *web.Registrar
Properties SwaggerProperties
Resolver bootstrap.BuildInfoResolver `optional:"true"`
}
func initialize(di initDI) {
di.Registrar.MustRegister(Content)
di.Registrar.MustRegister(NewSwaggerController(di.Properties, di.Resolver))
}
func bindSwaggerProperties(ctx *bootstrap.ApplicationContext) SwaggerProperties {
props := NewSwaggerSsoProperties()
if err := ctx.Config().Bind(props, SwaggerPrefix); err != nil {
panic(errors.Wrap(err, "failed to bind SwaggerSsoProperties"))
}
return *props
}
type secDI struct {
fx.In
SecRegistrar security.Registrar `optional:"true"`
Properties SwaggerProperties
}
// configureSecurity register security.Configurer that control how security works on endpoints
func configureSecurity(di secDI) {
if di.SecRegistrar != nil && di.Properties.Security.SecureDocs {
di.SecRegistrar.Register(&swaggerSecurityConfigurer{})
}
}
type DiscoveryCustomizerDIOut struct {
fx.Out
Customizer discovery.ServiceRegistrationCustomizer `group:"discovery_customizer"`
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package swagger
const SwaggerPrefix = "swagger"
type SwaggerProperties struct {
BasePath string `json:"base-path"`
Spec string `json:"spec"`
Security SwaggerSecurityProperties `json:"security"`
UI SwaggerUIProperties `json:"ui"`
}
type SwaggerSecurityProperties struct {
SecureDocs bool `json:"secure-docs"`
Sso SwaggerSsoProperties `json:"sso"`
}
type SwaggerSsoProperties struct {
BaseUrl string `json:"base-url"`
TokenPath string `json:"token-path"`
AuthorizePath string `json:"authorize-path"`
ClientId string `json:"client-id"`
ClientSecret string `json:"client-secret"`
AdditionalParams []ParameterProperties `json:"additional-params" binding:"omitempty"`
}
type ParameterProperties struct {
Name string `json:"name"`
DisplayName string `json:"display-name"`
CandidateSourceUrl string `json:"candidate-source-url"`
CandidateJsonPath string `json:"candidate-json-path"`
}
type SwaggerUIProperties struct {
Title string `json:"title"`
}
func NewSwaggerSsoProperties() *SwaggerProperties {
return &SwaggerProperties{}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package swagger
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/tokenauth"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
)
type swaggerSecurityConfigurer struct {
}
func (c *swaggerSecurityConfigurer) Configure(ws security.WebSecurity) {
// DSL style example
// for REST API
ws.Route(matcher.RouteWithPattern("/v2/api-docs").Or(matcher.RouteWithPattern("/v3/api-docs"))).
With(tokenauth.New()).
With(access.New().
Request(matcher.AnyRequest()).AllowIf(swaggerSpecAccessControl),
).
With(errorhandling.New())
}
func swaggerSpecAccessControl(auth security.Authentication) (decision bool, reason error) {
oa, ok := auth.(oauth2.Authentication)
if !ok {
return false, security.NewInsufficientAuthError("expected token authentication")
}
if oa.UserAuthentication() == nil {
return false, security.NewInsufficientAuthError("expected oauth user authentication")
}
if !(oa.OAuth2Request().Approved() && oa.OAuth2Request().Scopes().Has("read") && oa.OAuth2Request().Scopes().Has("write")) {
return false, security.NewInsufficientAuthError("expected read and write scope")
}
//and must be authenticated
return access.Authenticated(auth)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tenancy
import (
"container/list"
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/redis"
r "github.com/go-redis/redis/v8"
"github.com/google/uuid"
"strings"
)
const (
errTmplNotLoaded = `tenancy is not loaded`
)
//goland:noinspection GoNameStartsWithPackageName
type TenancyAccessor struct {
cachedRootID string
rc redis.Client
}
func newAccessor(rc redis.Client) *TenancyAccessor {
return &TenancyAccessor{
cachedRootID: "",
rc: rc,
}
}
func (a *TenancyAccessor) GetParent(ctx context.Context, tenantId string) (string, error) {
if !a.IsLoaded(ctx) {
return "", errors.New(errTmplNotLoaded)
}
gteValue := BuildSpsString(tenantId, IsChildrenOfPredict)
lteValue := BuildSpsString(tenantId, IsChildrenOfPredict, RedisZsetMaxByte)
zrange := &r.ZRangeBy{Min: ZInclusive(gteValue), Max: ZInclusive(lteValue)}
cmd := a.rc.ZRangeByLex(ctx, ZsetKey, zrange)
relations := cmd.Val()
if len(relations) == 0 {
return "", nil
} else if len(relations) > 1 {
return "", errors.New(fmt.Sprintf("Tenant should only have one parent, but tenant with Id %s has %d ", tenantId, len(relations)))
} else {
p, err := GetObjectOfSpo(relations[0])
return p, err
}
}
func (a *TenancyAccessor) GetChildren(ctx context.Context, tenantId string) ([]string, error) {
if !a.IsLoaded(ctx) {
return nil, errors.New(errTmplNotLoaded)
}
gteValue := BuildSpsString(tenantId, IsParentOfPredict)
lteValue := BuildSpsString(tenantId, IsParentOfPredict, RedisZsetMaxByte)
zrange := &r.ZRangeBy{Min: ZInclusive(gteValue), Max: ZInclusive(lteValue)}
cmd := a.rc.ZRangeByLex(ctx, ZsetKey, zrange)
var children = make([]string, len(cmd.Val()))
for i, relation := range cmd.Val() {
child, err := GetObjectOfSpo(relation)
if err != nil {
return nil, err
}
children[i] = child
}
return children, nil
}
func (a *TenancyAccessor) GetAncestors(ctx context.Context, tenantId string) ([]string, error) {
if !a.IsLoaded(ctx) {
return nil, errors.New(errTmplNotLoaded)
}
var ancestors = make([]string, 0)
p, err := a.GetParent(ctx, tenantId)
for p != "" && err == nil {
ancestors = append(ancestors, p)
p, err = a.GetParent(ctx, p)
}
if err != nil {
return nil, err
}
return ancestors, nil
}
func (a *TenancyAccessor) GetDescendants(ctx context.Context, tenantId string) ([]string, error) {
if !a.IsLoaded(ctx) {
return nil, errors.New(errTmplNotLoaded)
}
descendants := make([]string, 0)
idsToVisit := list.New()
idsToVisit.PushBack(tenantId)
for idsToVisit.Len() > 0 {
cmds, err := a.rc.Pipelined(ctx, func(pipeliner r.Pipeliner) error {
for idsToVisit.Len() > 0 {
id := idsToVisit.Front()
gteValue := BuildSpsString(id.Value.(string), IsParentOfPredict)
lteValue := BuildSpsString(id.Value.(string), IsParentOfPredict, RedisZsetMaxByte)
zrange := &r.ZRangeBy{Min: ZInclusive(gteValue), Max: ZInclusive(lteValue)}
pcmd := pipeliner.ZRangeByLex(ctx, ZsetKey, zrange)
if pcmd.Err() == nil {
idsToVisit.Remove(id)
} else {
return pcmd.Err()
}
}
return nil
})
if err != nil {
return nil, err
}
var children []string
for _, c := range cmds {
for _, relation := range c.(*r.StringSliceCmd).Val() {
child, err := GetObjectOfSpo(relation)
if err != nil {
return nil, err
}
children = append(children, child)
}
}
descendants = append(descendants, children...)
for _, child := range children {
idsToVisit.PushBack(child)
}
}
return descendants, nil
}
// GetRoot will return the root tenantID.
// Because the root tenantId won't change once system is started, we can cache
// it after first successful read.
func (a *TenancyAccessor) GetRoot(ctx context.Context) (string, error) {
if !a.IsLoaded(ctx) {
return "", errors.New(errTmplNotLoaded)
}
if a.cachedRootID != "" {
return a.cachedRootID, nil
}
cmd := a.rc.Get(ctx, RootTenantKey)
if cmd.Err() != nil {
a.cachedRootID = ""
} else {
a.cachedRootID = cmd.Val()
}
return a.cachedRootID, cmd.Err()
}
func (a *TenancyAccessor) IsLoaded(ctx context.Context) bool {
cmd := a.rc.Get(ctx, StatusKey)
if cmd.Err() != nil {
return false
}
return strings.HasPrefix(cmd.Val(), STATUS_LOADED)
}
func (a *TenancyAccessor) GetTenancyPath(ctx context.Context, tenantId string) ([]uuid.UUID, error) {
current, err := uuid.Parse(tenantId)
if err != nil {
return nil, err
}
path := []uuid.UUID{current}
ancestors, err := a.GetAncestors(ctx, tenantId)
if err != nil {
return nil, err
}
for _, str := range ancestors {
id, err := uuid.Parse(str)
if err != nil {
return nil, err
}
path = append(path, id)
}
//reverse the order to that the result is root tenant id -> current tenant id
//fi is index going forward starting from 0,
//ri is index going backward starting from last element
//swap the element at ri and ri
for fi, ri := 0, len(path)-1; fi < ri; fi, ri = fi+1, ri-1 {
path[fi], path[ri] = path[ri], path[fi]
}
return path, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tenancy
import (
"context"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/google/uuid"
)
const ZsetKey = "tenant-hierarchy"
const IsParentOfPredict = "is-parent-of"
const IsChildrenOfPredict = "is-children-of"
const RedisZsetMaxByte = "\uffff"
const RootTenantKey = "root-tenant-id"
const StatusKey = "tenant-hierarchy-status"
const STATUS_IN_PROGRESS = "IN_PROGRESS"
const STATUS_LOADED = "LOADED"
const STATUS_FAILED_TO_LOAD_ROOT_TENANT = "FAILED_TO_LOAD_ROOT_TENANT"
type Accessor interface {
GetParent(ctx context.Context, tenantId string) (string, error)
GetChildren(ctx context.Context, tenantId string) ([]string, error)
GetAncestors(ctx context.Context, tenantId string) ([]string, error)
GetDescendants(ctx context.Context, tenantId string) ([]string, error)
GetRoot(ctx context.Context) (string, error)
IsLoaded(ctx context.Context) bool
GetTenancyPath(ctx context.Context, tenantId string) ([]uuid.UUID, error)
}
// IsLoaded returns if tenancy information is available.
// Note that callers normally don't need to check this flag directly. Other top-level functions Get...() returns error if not loaded
func IsLoaded(ctx context.Context) bool {
return internalAccessor.IsLoaded(ctx)
}
func GetParent(ctx context.Context, tenantId string) (string, error) {
return internalAccessor.GetParent(ctx, tenantId)
}
func GetChildren(ctx context.Context, tenantId string) ([]string, error) {
return internalAccessor.GetChildren(ctx, tenantId)
}
func GetAncestors(ctx context.Context, tenantId string) ([]string, error) {
return internalAccessor.GetAncestors(ctx, tenantId)
}
func GetDescendants(ctx context.Context, tenantId string) ([]string, error) {
return internalAccessor.GetDescendants(ctx, tenantId)
}
/*
GetRoot because root tenantId won't change once system is started, we can cache it after first successful read.
*/
func GetRoot(ctx context.Context) (string, error) {
return internalAccessor.GetRoot(ctx)
}
func GetTenancyPath(ctx context.Context, tenantId string) ([]uuid.UUID, error) {
return internalAccessor.GetTenancyPath(ctx, tenantId)
}
// AnyHasDescendant returns true if any of "tenantIDs" in utils.StringSet contains "descendant" or its ancestors
func AnyHasDescendant(ctx context.Context, tenantIDs utils.StringSet, descendant string) bool {
if tenantIDs == nil || descendant == "" {
return false
}
if tenantIDs.Has(descendant) {
return true
}
ancestors, err := GetAncestors(ctx, descendant)
if err != nil {
return false
}
for _, ancestor := range ancestors {
if tenantIDs.Has(ancestor) {
return true
}
}
return false
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package th_loader
import "context"
type TenantHierarchyStore interface {
GetIterator(ctx context.Context) (TenantIterator, error)
}
type TenantIterator interface {
Next() bool
Scan(ctx context.Context) (Tenant, error)
Close() error
Err() error
}
type Tenant interface {
GetId() string
GetParentId() string //use pointer for nil
}
type Loader interface {
LoadTenantHierarchy(ctx context.Context) (err error)
}
func LoadTenantHierarchy(ctx context.Context) (err error) {
return internalLoader.LoadTenantHierarchy(ctx)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package th_loader
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/tenancy"
r "github.com/go-redis/redis/v8"
"github.com/google/uuid"
)
type TenancyLoader struct {
rc redis.Client
store TenantHierarchyStore
accessor tenancy.Accessor
}
func NewLoader(rc redis.Client, store TenantHierarchyStore, accessor tenancy.Accessor) *TenancyLoader {
return &TenancyLoader{
rc: rc,
store: store,
accessor: accessor,
}
}
func (l *TenancyLoader) LoadTenantHierarchy(ctx context.Context) (err error) {
//sets status to in progress
if cmd := l.rc.Set(ctx, tenancy.StatusKey, tenancy.STATUS_IN_PROGRESS, 0); cmd.Err() != nil {
return cmd.Err()
}
//clear out previously loaded data - this way in case the transaction below failed, we get empty data instead of stale data
if cmd := l.rc.Del(ctx, tenancy.ZsetKey); cmd.Err() != nil {
return cmd.Err()
}
if cmd := l.rc.Del(ctx, tenancy.RootTenantKey); cmd.Err() != nil {
return cmd.Err()
}
//deletes the zset, load its content and set status to loaded_{uuid}
//this delete is necessary because if two transaction blocks below runs sequentially, we don't want twice the data.
//this function is to be executed in the transaction below
var loadTenantHierarchy = func(tx *r.Tx) error {
cmd := tx.Del(ctx, tenancy.ZsetKey)
if cmd.Err() != nil {
return cmd.Err()
}
var relations []*r.Z
it, err := l.store.GetIterator(ctx)
if err != nil {
return err
}
defer func() { _ = it.Close() }()
for it.Next() {
t, err := it.Scan(ctx)
if err != nil {
return err
}
if t.GetParentId() != "" {
relations = append(relations,
&r.Z{Member: tenancy.BuildSpsString(t.GetId(), tenancy.IsChildrenOfPredict, t.GetParentId())},
&r.Z{Member: tenancy.BuildSpsString(t.GetParentId(), tenancy.IsParentOfPredict, t.GetId())})
} else {
statusCmd := tx.Set(ctx, tenancy.RootTenantKey, t.GetId(), 0)
if statusCmd.Err() != nil {
return statusCmd.Err()
}
}
}
// need to wait until root tenant exists
if cmd := tx.Get(ctx, tenancy.RootTenantKey); cmd.Err() != nil {
logger.WithContext(ctx).Errorf("Failed to load root tenant due to error: %v", cmd.Err())
if statusCmd := tx.Set(ctx, tenancy.StatusKey, tenancy.STATUS_FAILED_TO_LOAD_ROOT_TENANT, 0); statusCmd.Err() != nil {
logger.WithContext(ctx).Errorf("Failed to set status to STATUS_FAILED_TO_LOAD_ROOT_TENANT due to error: %v", statusCmd.Err())
return statusCmd.Err()
}
return cmd.Err()
}
if len(relations) != 0 {
cmd = tx.ZAdd(ctx, tenancy.ZsetKey, relations...)
if cmd.Err() != nil {
return cmd.Err()
}
}
statusCmd := tx.Set(ctx, tenancy.StatusKey, fmt.Sprintf("%s_%s", tenancy.STATUS_LOADED, uuid.New().String()), 0)
if statusCmd.Err() != nil {
return statusCmd.Err()
}
return nil
}
//watches the status key, if status changed, the transaction is aborted
err = l.rc.Watch(ctx, loadTenantHierarchy, tenancy.StatusKey)
if err != nil {
//we check if the failure is due to transaction aborted.
//if the status is loaded, that means this process failed because another auth server instance has loaded the
//successfully at the same time. If that's the case, we can start up.
loaded := l.accessor.IsLoaded(ctx)
if loaded {
return nil
}
}
return
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package th_loader
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/tenancy"
"go.uber.org/fx"
)
var logger = log.New("Tenancy.Load")
var internalLoader Loader
var Module = &bootstrap.Module{
Name: "tenancy-loader",
Precedence: bootstrap.TenantHierarchyLoaderPrecedence,
Options: []fx.Option{
fx.Provide(defaultLoader()),
fx.Invoke(initializeTenantHierarchy),
},
}
const (
fxNameLoader = "tenant-hierarchy/loader"
)
func Use() {
tenancy.Use()
bootstrap.Register(Module)
}
type loaderDI struct {
fx.In
Ctx *bootstrap.ApplicationContext
Store TenantHierarchyStore
Cf redis.ClientFactory
Prop tenancy.CacheProperties
Accessor tenancy.Accessor `name:"tenancy/accessor"`
UnnamedLoader Loader `optional:"true"`
}
func defaultLoader() fx.Annotated {
return fx.Annotated{
Name: fxNameLoader,
Target: provideLoader,
}
}
func provideLoader(di loaderDI) Loader {
if di.UnnamedLoader != nil {
internalLoader = di.UnnamedLoader
return di.UnnamedLoader
}
rc, e := di.Cf.New(di.Ctx, func(opt *redis.ClientOption) {
opt.DbIndex = di.Prop.DbIndex
})
if e != nil {
panic(e)
}
internalLoader = NewLoader(rc, di.Store, di.Accessor)
return internalLoader
}
type initDi struct {
fx.In
AppCtx *bootstrap.ApplicationContext
EffectiveLoader Loader `name:"tenant-hierarchy/loader"`
}
func initializeTenantHierarchy(di initDi) error {
ctx := di.AppCtx
logger.WithContext(ctx).Infof("started loading tenant hierarchy")
internalLoader = di.EffectiveLoader
err := LoadTenantHierarchy(ctx)
if err != nil {
logger.WithContext(ctx).Errorf("tenant hierarchy not loaded due to %v", err)
} else {
logger.WithContext(ctx).Infof("finished loading tenant hierarchy")
}
return err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package th_modifier
import (
"context"
)
type Modifier interface {
RemoveTenant(ctx context.Context, tenantId string) error
AddTenant(ctx context.Context, tenantId string, parentId string) error
}
func RemoveTenant(ctx context.Context, tenantId string) error {
return internalModifier.RemoveTenant(ctx, tenantId)
}
func AddTenant(ctx context.Context, tenantId string, parentId string) error {
return internalModifier.AddTenant(ctx, tenantId, parentId)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package th_modifier
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/tenancy"
"github.com/cisco-open/go-lanai/pkg/utils"
r "github.com/go-redis/redis/v8"
)
type TenancyModifer struct {
rc redis.Client
accessor tenancy.Accessor
}
func newModifier(rc redis.Client, accessor tenancy.Accessor) *TenancyModifer {
return &TenancyModifer{
rc: rc,
accessor: accessor,
}
}
func (m *TenancyModifer) RemoveTenant(ctx context.Context, tenantId string) error {
if tenantId == "" {
return errors.New("tenantId should not be empty")
}
logger.Debugf("remove tenantId %s", tenantId)
children, err := m.accessor.GetChildren(ctx, tenantId)
if err != nil {
return err
}
if len(children) != 0 {
return errors.New("can't remove tenant that still have children")
}
parentId, err := m.accessor.GetParent(ctx, tenantId)
if err != nil {
return err
}
if parentId == "" {
return errors.New("this tenant is root tenant because it has no parent. root tenant can't be deleted")
}
relations := []interface{}{
tenancy.BuildSpsString(tenantId, tenancy.IsChildrenOfPredict, parentId),
tenancy.BuildSpsString(parentId, tenancy.IsParentOfPredict, tenantId)}
cmd := m.rc.ZRem(ctx, tenancy.ZsetKey, relations...)
return cmd.Err()
}
func (m *TenancyModifer) AddTenant(ctx context.Context, tenantId string, parentId string) error {
if tenantId == "" || parentId == "" {
return errors.New("tenantId and parentId should not be empty")
}
logger.Debugf("add tenantId %s parentId %s", tenantId, parentId)
p, err := m.accessor.GetParent(ctx, tenantId)
if err != nil {
return err
}
if p != "" {
return errors.New("this tenant already have a parent")
}
root, err := m.accessor.GetRoot(ctx)
if err != nil {
return err
}
if tenantId == root {
return errors.New("this tenant is the root")
}
ancestors, err := m.accessor.GetAncestors(ctx, parentId)
if err != nil {
return err
}
set := utils.NewStringSet(ancestors...)
if set.Has(tenantId) || tenantId == parentId {
return errors.New("this relationship introduces a cycle in the tenant hierarchy")
}
relations := []*r.Z{
{Member: tenancy.BuildSpsString(tenantId, tenancy.IsChildrenOfPredict, parentId)},
{Member: tenancy.BuildSpsString(parentId, tenancy.IsParentOfPredict, tenantId)}}
cmd := m.rc.ZAdd(ctx, tenancy.ZsetKey, relations...)
return cmd.Err()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package th_modifier
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/redis"
"github.com/cisco-open/go-lanai/pkg/tenancy"
"go.uber.org/fx"
)
var logger = log.New("Tenancy.Modify")
var internalModifier Modifier
var Module = &bootstrap.Module{
Name: "tenancy-modifier",
Precedence: bootstrap.TenantHierarchyModifierPrecedence,
Options: []fx.Option{
fx.Provide(provideModifier),
fx.Invoke(setup),
},
}
func Use() {
tenancy.Use()
bootstrap.Register(Module)
}
type modifierDI struct {
fx.In
Ctx *bootstrap.ApplicationContext
Cf redis.ClientFactory
Prop tenancy.CacheProperties
Accessor tenancy.Accessor `name:"tenancy/accessor"`
}
func provideModifier(di modifierDI) Modifier {
rc, e := di.Cf.New(di.Ctx, func(opt *redis.ClientOption) {
opt.DbIndex = di.Prop.DbIndex
})
if e != nil {
panic(e)
}
internalModifier = newModifier(rc, di.Accessor)
return internalModifier
}
func setup(_ Modifier) {
// currently, keep everything default
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tenancy
import (
"errors"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/redis"
"go.uber.org/fx"
)
var internalAccessor Accessor
var Module = &bootstrap.Module{
Name: "tenant-hierarchy",
Precedence: bootstrap.TenantHierarchyAccessorPrecedence,
Options: []fx.Option{
fx.Provide(bindCacheProperties),
fx.Provide(defaultTenancyAccessorProvider()),
fx.Invoke(setup),
},
}
const (
fxNameAccessor = "tenancy/accessor"
)
func Use() {
bootstrap.Register(Module)
}
type defaultDI struct {
fx.In
Ctx *bootstrap.ApplicationContext
Cf redis.ClientFactory `optional:"true"`
Prop CacheProperties `optional:"true"`
UnnamedTenancyAccessor Accessor `optional:"true"`
}
func defaultTenancyAccessorProvider() fx.Annotated {
return fx.Annotated{
Name: fxNameAccessor,
Target: provideAccessor,
}
}
func provideAccessor(di defaultDI) Accessor {
if di.UnnamedTenancyAccessor != nil {
internalAccessor = di.UnnamedTenancyAccessor
return di.UnnamedTenancyAccessor
}
if di.Cf == nil {
panic(errors.New("redis client factory is required"))
}
rc, e := di.Cf.New(di.Ctx, func(opt *redis.ClientOption) {
opt.DbIndex = di.Prop.DbIndex
})
if e != nil {
panic(e)
}
internalAccessor = newAccessor(rc)
return internalAccessor
}
type setupDI struct {
fx.In
EffectiveAccessor Accessor `name:"tenancy/accessor"`
}
func setup(_ setupDI) {
// keep it as default for now
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tenancy
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/pkg/errors"
)
/***********************
Cache
************************/
// CachePropertiesPrefix is the configuration prefix bound into CacheProperties.
const CachePropertiesPrefix = "security.cache"
// CacheProperties configures the redis connection of the default tenancy Accessor.
type CacheProperties struct {
// DbIndex is the redis database index used when creating the Accessor's client.
DbIndex int `json:"db-index"`
}
// newCacheProperties returns a zero-valued CacheProperties to bind configuration into.
func newCacheProperties() *CacheProperties {
	return new(CacheProperties)
}
// bindCacheProperties binds the configuration under CachePropertiesPrefix into
// a CacheProperties value. It panics when binding fails.
func bindCacheProperties(ctx *bootstrap.ApplicationContext) CacheProperties {
	bound := newCacheProperties()
	err := ctx.Config().Bind(bound, CachePropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind CacheProperties"))
	}
	return *bound
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"fmt"
th_loader "github.com/cisco-open/go-lanai/pkg/tenancy/loader"
"github.com/ghodss/yaml"
"github.com/google/uuid"
"io"
"io/fs"
)
// TestData is the document layout of a test tenants file.
type TestData struct {
// Tenants lists the tenants defined in the file.
Tenants []TestTenant `json:"tenants"`
// UUIDs maps tenant names to tenant IDs.
UUIDs map[string]uuid.UUID `json:"uuids"`
}
// TestTenant is one tenant entry. Its IDs are resolved by name through uuidMapping.
type TestTenant struct {
Name string `json:"name"`
// Parent is the parent tenant's name (not its ID); empty means no parent.
Parent string `json:"parent"`
// uuidMapping is the name->ID table, populated from TestData.UUIDs by TestTenantStore.GetIterator.
uuidMapping map[string]uuid.UUID
}
// GetId returns this tenant's ID, looked up by Name in uuidMapping.
// An empty Name yields "". A name missing from the mapping yields the
// string form of the zero UUID.
func (t TestTenant) GetId() string {
	if t.Name == "" {
		return ""
	}
	id := t.uuidMapping[t.Name]
	return id.String()
}
// GetParentId returns the parent tenant's ID, looked up by Parent in uuidMapping.
// An empty Parent yields "". A name missing from the mapping yields the
// string form of the zero UUID.
func (t TestTenant) GetParentId() string {
	if t.Parent == "" {
		return ""
	}
	parentID := t.uuidMapping[t.Parent]
	return parentID.String()
}
// TestTenantStore loads test tenants from a file in SourceFS and serves them via GetIterator.
type TestTenantStore struct {
// TestData caches the parsed file content.
TestData
SourceFS fs.FS
// SourcePath is the path of the tenants file inside SourceFS.
SourcePath string
}
// Reset drops any cached tenants and IDs and points the store at a new source
// file. The file is read again by the next GetIterator call.
func (s *TestTenantStore) Reset(srcFS fs.FS, srcPath string) {
	s.TestData = TestData{}
	s.SourceFS, s.SourcePath = srcFS, srcPath
}
// IDof returns the ID of the named tenant from the loaded UUID table.
// It returns "" when the name is empty or no table has been loaded, and the
// string form of the zero UUID for unknown names.
func (s *TestTenantStore) IDof(tenant string) string {
	if tenant == "" || s.UUIDs == nil {
		return ""
	}
	id := s.UUIDs[tenant]
	return id.String()
}
func (s *TestTenantStore) GetIterator(_ context.Context) (th_loader.TenantIterator, error) {
if len(s.SourcePath) == 0 || s.SourceFS == nil {
return &TestTenantIterator{Tenants: []TestTenant{}}, nil
}
if len(s.Tenants) == 0 {
data, e := fs.ReadFile(s.SourceFS, s.SourcePath)
if e != nil {
return nil, fmt.Errorf("unable to load test tenants file: %v", e)
}
if e := yaml.Unmarshal(data, &s.TestData); e != nil {
return nil, fmt.Errorf("unable to parse test tenants file: %v", e)
}
for i := range s.Tenants {
s.Tenants[i].uuidMapping = s.UUIDs
}
}
return &TestTenantIterator{Tenants: s.Tenants}, nil
}
// TestTenantIterator iterates over a slice of TestTenant, consuming it from the front on each Scan.
type TestTenantIterator struct {
Tenants []TestTenant
}
// Next reports whether at least one tenant remains to be scanned.
func (i *TestTenantIterator) Next() bool {
	return len(i.Tenants) > 0
}
// Scan returns the next tenant and removes it from the iterator.
// It returns io.EOF once all tenants have been consumed.
func (i *TestTenantIterator) Scan(_ context.Context) (th_loader.Tenant, error) {
	if len(i.Tenants) == 0 {
		return nil, io.EOF
	}
	head := i.Tenants[0]
	i.Tenants = i.Tenants[1:]
	return head, nil
}
// Close is a no-op: the iterator holds no resources. It always returns nil.
func (i *TestTenantIterator) Close() error {
return nil
}
// Err always returns nil; Scan signals exhaustion with io.EOF instead.
func (i *TestTenantIterator) Err() error {
return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tenancy
import (
"errors"
"fmt"
"strings"
)
const spoPrefix = "spo"

// BuildSpsString builds a subject-predicate(-object) key of the form
// "spo:<subject>:<predict>" or, when at least one object is given,
// "spo:<subject>:<predict>:<object>". Only the first object is used; any
// further objects are ignored.
func BuildSpsString(subject string, predict string, object ...string) string {
	base := fmt.Sprintf("%s:%s:%s", spoPrefix, subject, predict)
	if len(object) == 0 {
		return base
	}
	return fmt.Sprintf("%s:%s", base, object[0])
}
// GetObjectOfSpo returns the object part of a key built by BuildSpsString.
// The key must consist of exactly four ':'-separated parts; otherwise an error
// is returned.
func GetObjectOfSpo(spo string) (string, error) {
	if parts := strings.Split(spo, ":"); len(parts) == 4 {
		return parts[3], nil
	}
	return "", errors.New("spo relation has no object part")
}
// ZInclusive returns the inclusive lexicographical range bound for lower
// (the value prefixed with "["), as used by range queries on sorted sets.
func ZInclusive(lower string) string {
	return fmt.Sprint("[", lower)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tracing
import (
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
/**********************
SpanOptions
**********************/
// SpanTag returns a SpanOption that sets tag key to v.
func SpanTag(key string, v interface{}) SpanOption {
	return func(s opentracing.Span) { s.SetTag(key, v) }
}

// SpanBaggageItem returns a SpanOption that sets baggage item restrictedKey to s.
func SpanBaggageItem(restrictedKey string, s string) SpanOption {
	value := s
	return func(span opentracing.Span) { span.SetBaggageItem(restrictedKey, value) }
}

// SpanKind returns a SpanOption that sets the standard "span.kind" tag.
func SpanKind(v ext.SpanKindEnum) SpanOption {
	return func(s opentracing.Span) { ext.SpanKind.Set(s, v) }
}

// SpanComponent returns a SpanOption that sets the standard "component" tag.
func SpanComponent(v string) SpanOption {
	return func(s opentracing.Span) { ext.Component.Set(s, v) }
}

// SpanHttpUrl returns a SpanOption that sets the standard "http.url" tag.
func SpanHttpUrl(v string) SpanOption {
	return func(s opentracing.Span) { ext.HTTPUrl.Set(s, v) }
}

// SpanHttpMethod returns a SpanOption that sets the standard "http.method" tag.
func SpanHttpMethod(v string) SpanOption {
	return func(s opentracing.Span) { ext.HTTPMethod.Set(s, v) }
}

// SpanHttpStatusCode returns a SpanOption that sets the standard "http.status_code" tag.
// v is converted to uint16, so values outside 0..65535 wrap around.
func SpanHttpStatusCode(v int) SpanOption {
	code := uint16(v)
	return func(s opentracing.Span) { ext.HTTPStatusCode.Set(s, code) }
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tracing
import (
"context"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/opentracing/opentracing-go"
"time"
)
// spanKey is the private context key type under which the SpanRewinder is stored.
type spanKey struct{}
var spanFinisherKey = spanKey{}
// DefaultLogValuers is used by log package to extract tracing information in log templates.
// The defaults below always return nil. A tracer implementation replaces them in its init
// (the jaegertracing package does so); "tracing/init" pulls that package in.
var DefaultLogValuers = LogValuers{
TraceIDValuer: func(context.Context) interface{} { return nil },
SpanIDValuer: func(context.Context) interface{} { return nil },
ParentIDValuer: func(context.Context) interface{} { return nil },
}
// SpanOption mutates a span, e.g. by setting a tag. See SpanOperator.WithOptions.
type SpanOption func(opentracing.Span)
// SpanRewinder returns the context that was current before a span was started by SpanOperator.
type SpanRewinder func() context.Context
/**********************
Context
**********************/
// LogValuers groups the context valuers that extract the trace, span and parent span IDs for logging.
type LogValuers struct {
TraceIDValuer log.ContextValuer
SpanIDValuer log.ContextValuer
ParentIDValuer log.ContextValuer
}
// ContextValuers exposes the valuers as log fields "traceId", "spanId" and "parentId".
func (v LogValuers) ContextValuers() log.ContextValuers {
	valuers := make(log.ContextValuers, 3)
	valuers["traceId"] = v.TraceIDValuer
	valuers["spanId"] = v.SpanIDValuer
	valuers["parentId"] = v.ParentIDValuer
	return valuers
}
// SpanFromContext returns the span carried by ctx. If ctx has none but wraps a
// gin.Context (as detected by web.GinContext), the span is looked up in that
// request's context instead. It returns nil when neither carries a span.
//
//nolint:contextcheck
func SpanFromContext(ctx context.Context) (span opentracing.Span) {
	if span = opentracing.SpanFromContext(ctx); span != nil {
		return span
	}
	gc := web.GinContext(ctx)
	if gc == nil {
		return nil
	}
	return opentracing.SpanFromContext(gc.Request.Context())
}
// SpanRewinderFromContext returns the SpanRewinder stored in ctx. If ctx has none
// but wraps a gin.Context, the request's context is checked as well. It returns nil
// when no rewinder is found.
func SpanRewinderFromContext(ctx context.Context) SpanRewinder {
	rewinder, found := ctx.Value(spanFinisherKey).(SpanRewinder)
	if found {
		return rewinder
	}
	gc := web.GinContext(ctx)
	if gc == nil {
		return nil
	}
	if rewinder, found = gc.Request.Context().Value(spanFinisherKey).(SpanRewinder); found {
		return rewinder
	}
	return nil
}
// ContextWithSpanRewinder returns a copy of ctx carrying finisher, which can be
// retrieved with SpanRewinderFromContext.
func ContextWithSpanRewinder(ctx context.Context, finisher SpanRewinder) context.Context {
	key := spanFinisherKey
	return context.WithValue(ctx, key, finisher)
}

// TraceIdFromContext returns what DefaultLogValuers.TraceIDValuer reports for ctx.
func TraceIdFromContext(ctx context.Context) interface{} {
	valuer := DefaultLogValuers.TraceIDValuer
	return valuer(ctx)
}

// SpanIdFromContext returns what DefaultLogValuers.SpanIDValuer reports for ctx.
func SpanIdFromContext(ctx context.Context) interface{} {
	valuer := DefaultLogValuers.SpanIDValuer
	return valuer(ctx)
}

// ParentIdFromContext returns what DefaultLogValuers.ParentIDValuer reports for ctx.
func ParentIdFromContext(ctx context.Context) interface{} {
	valuer := DefaultLogValuers.ParentIDValuer
	return valuer(ctx)
}
/**********************
Span Operators
**********************/
// SpanOperator is a chainable builder for starting, updating and finishing spans with one tracer.
// Its setters mutate and return the same *SpanOperator.
type SpanOperator struct {
tracer opentracing.Tracer
// name is the operation name of spans started by this operator.
name string
startOptions []opentracing.StartSpanOption
// updateOptions are applied to newly started spans, in UpdateCurrentSpan, and before Finish.
updateOptions []SpanOption
// finishOptions is passed to FinishWithOptions; Finish overwrites its FinishTime.
finishOptions opentracing.FinishOptions
}
// WithTracer creates a SpanOperator using tracer, with empty start and update options.
func WithTracer(tracer opentracing.Tracer) *SpanOperator {
	op := new(SpanOperator)
	op.tracer = tracer
	op.startOptions = []opentracing.StartSpanOption{}
	op.updateOptions = []SpanOption{}
	return op
}
// Setters

// WithOpName sets the operation name of spans started by op and returns op.
func (op *SpanOperator) WithOpName(name string) *SpanOperator {
	op.name = name
	return op
}

// WithStartOptions appends options passed to the tracer when starting a span, and returns op.
func (op *SpanOperator) WithStartOptions(options ...opentracing.StartSpanOption) *SpanOperator {
	merged := append(op.startOptions, options...)
	op.startOptions = merged
	return op
}

// WithOptions appends SpanOptions applied to spans this operator starts, updates or finishes, and returns op.
func (op *SpanOperator) WithOptions(exts ...SpanOption) *SpanOperator {
	merged := append(op.updateOptions, exts...)
	op.updateOptions = merged
	return op
}
// Operations

// UpdateCurrentSpan applies op's update options to the span carried by ctx.
// It does nothing when ctx has no span.
func (op *SpanOperator) UpdateCurrentSpan(ctx context.Context) {
	if span := SpanFromContext(ctx); span != nil {
		op.applyUpdateOptions(span)
	}
}
// Finish finishes the span carried by ctx, if any, after applying op's update
// options. The finish time is taken as the current UTC time.
// Note: The finished span is still counted as "current span" in ctx.
//
// To also restore the context that preceded the span, use FinishAndRewind.
func (op *SpanOperator) Finish(ctx context.Context) {
	span := SpanFromContext(ctx)
	if span == nil {
		return
	}
	op.applyUpdateOptions(span)
	op.finishOptions.FinishTime = time.Now().UTC()
	span.FinishWithOptions(op.finishOptions)
}
// FinishAndRewind finishes the current span, if any, and returns the context that
// was current before that span was started, when a SpanRewinder is available
// (no guarantees). Otherwise ctx itself is returned.
// Callers shall not continue to use the old context after this call.
// Note: all values added to the context during the current span are lost; it
// behaves like a rewind operation.
func (op *SpanOperator) FinishAndRewind(ctx context.Context) context.Context {
	op.Finish(ctx)
	if rewind := SpanRewinderFromContext(ctx); rewind != nil {
		return rewind()
	}
	return ctx
}
// NewSpanOrDescendant creates a new root span when ctx has none; otherwise it
// starts a child span referencing the current one via opentracing.ChildOf.
func (op *SpanOperator) NewSpanOrDescendant(ctx context.Context) context.Context {
	childOf := func(parent opentracing.Span) opentracing.SpanReference {
		return opentracing.ChildOf(parent.Context())
	}
	return op.newSpan(ctx, childOf, true)
}

// NewSpanOrFollows creates a new root span when ctx has none; otherwise it
// starts a span referencing the current one via opentracing.FollowsFrom.
func (op *SpanOperator) NewSpanOrFollows(ctx context.Context) context.Context {
	followsFrom := func(prev opentracing.Span) opentracing.SpanReference {
		return opentracing.FollowsFrom(prev.Context())
	}
	return op.newSpan(ctx, followsFrom, true)
}

// DescendantOrNoSpan starts a child span (opentracing.ChildOf) when ctx carries a
// span; otherwise it returns ctx unchanged.
func (op *SpanOperator) DescendantOrNoSpan(ctx context.Context) context.Context {
	childOf := func(parent opentracing.Span) opentracing.SpanReference {
		return opentracing.ChildOf(parent.Context())
	}
	return op.newSpan(ctx, childOf, false)
}

// FollowsOrNoSpan starts a span via opentracing.FollowsFrom when ctx carries a
// span; otherwise it returns ctx unchanged.
func (op *SpanOperator) FollowsOrNoSpan(ctx context.Context) context.Context {
	followsFrom := func(prev opentracing.Span) opentracing.SpanReference {
		return opentracing.FollowsFrom(prev.Context())
	}
	return op.newSpan(ctx, followsFrom, false)
}
// ForceNewSpan force to create a new span and discard any existing span
// Warning: Internal usage, use with caution
func (op *SpanOperator) ForceNewSpan(ctx context.Context) context.Context {
return op.newSpan(ctx, nil, true)
}
// createSpanRewinder returns a SpanRewinder that yields ctx as given, i.e. the
// context before a new span is attached to it.
func (op *SpanOperator) createSpanRewinder(ctx context.Context) SpanRewinder {
	previous := ctx
	return func() context.Context { return previous }
}
// spanReferencer builds the reference (e.g. ChildOf / FollowsFrom) from an existing span to the span being started.
type spanReferencer func(opentracing.Span) opentracing.SpanReference
// newSpan starts a span named op.name and returns a context carrying it together
// with a SpanRewinder that restores ctx.
//
// When ctx already has a span and referencer is non-nil, the new span references
// it (the reference is placed before op's start options). When ctx has no span,
// a root span is started only if must is true; otherwise ctx is returned unchanged.
func (op *SpanOperator) newSpan(ctx context.Context, referencer spanReferencer, must bool) context.Context {
	current := SpanFromContext(ctx)
	if current == nil && !must {
		return ctx
	}
	rewinder := op.createSpanRewinder(ctx)
	opts := op.startOptions
	if current != nil && referencer != nil {
		opts = append([]opentracing.StartSpanOption{referencer(current)}, op.startOptions...)
	}
	started := op.tracer.StartSpan(op.name, opts...)
	op.applyUpdateOptions(started)
	return opentracing.ContextWithSpan(ContextWithSpanRewinder(ctx, rewinder), started)
}
// applyUpdateOptions applies every update option of op to span, in order.
func (op *SpanOperator) applyUpdateOptions(span opentracing.Span) {
	for idx := range op.updateOptions {
		op.updateOptions[idx](span)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tracing
import (
"context"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/cisco-open/go-lanai/pkg/tracing/instrument"
jaegertracing "github.com/cisco-open/go-lanai/pkg/tracing/jaeger"
"github.com/opentracing/opentracing-go"
"go.uber.org/fx"
)
// logger is the "Tracing" logger used by the tracer closing hooks.
var logger = log.New("Tracing")
// Module provides tracing properties, the opentracing.Tracer (provideTracer) and the CLI runner
// tracing hooks, then invokes initialize. Its options are registered as PriorityOptions.
var Module = &bootstrap.Module{
Name: "Tracing",
Precedence: bootstrap.TracingPrecedence,
PriorityOptions: []fx.Option{
fx.Provide(tracing.BindTracingProperties),
fx.Provide(provideTracer),
fx.Provide(instrument.CliRunnerTracingProvider()),
fx.Invoke(initialize),
},
}
func init() {
log.RegisterContextLogFields(tracing.DefaultLogValuers.ContextValuers())
}
// Use registers the tracing Module and enables bootstrap tracing on the global
// bootstrapper. Call it from a service's main() to include this module.
func Use() {
	bootstrap.Register(Module)
	global := bootstrap.GlobalBootstrapper()
	EnableBootstrapTracing(global)
}
// TracerClosingHook is an optional fx.Hook that closes the configured tracer on stop; nil when nothing needs closing.
type TracerClosingHook *fx.Hook
// defaultTracerCloser closes the default bootstrap tracer created by the most recent EnableBootstrapTracing call.
var defaultTracerCloser fx.Hook
// kCtxDefaultTracerCloser is the application-context key under which EnableBootstrapTracing stores its closer hook.
type kCtxDefaultTracerCloser struct {}
// EnableBootstrapTracing enables bootstrap tracing on a given bootstrapper.
// bootstrap.GlobalBootstrapper() should be used for regular application that uses bootstrap.Execute()
//
// It creates the default bootstrap tracer, registers the bootstrap/start/stop tracing
// context options on bootstrapper, and builds an fx.Hook that closes that tracer on stop.
// The hook is stored in the package-level defaultTracerCloser and in bootstrapper's
// initial application context, so initialize can close the tracer of this bootstrapper.
func EnableBootstrapTracing(bootstrapper *bootstrap.Bootstrapper) {
	appTracer, closer := jaegertracing.NewDefaultTracer()
	instrument.EnableBootstrapTracing(bootstrapper, appTracer)
	closerHook := fx.Hook{
		OnStop: func(ctx context.Context) error {
			logger.WithContext(ctx).Infof("closing default Tracer...")
			e := closer.Close()
			if e != nil {
				logger.WithContext(ctx).Errorf("failed to close default Tracer: %v", e)
			}
			logger.WithContext(ctx).Infof("default Tracer closed")
			return e
		},
	}
	defaultTracerCloser = closerHook
	// Capture this call's hook instead of reading the package variable when the option runs.
	// Otherwise, a later call for another bootstrapper would overwrite it, and this bootstrapper's
	// context would carry (and close) the other bootstrapper's tracer instead of its own.
	bootstrapper.AddInitialAppContextOptions(func(ctx context.Context) context.Context {
		return context.WithValue(ctx, kCtxDefaultTracerCloser{}, closerHook)
	})
}
/**************************
Provide dependencies
***************************/
// tracerOut is the fx result of provideTracer.
type tracerOut struct {
fx.Out
// Tracer is the effective tracer; opentracing.NoopTracer when tracing is disabled or no tracer is enabled.
Tracer opentracing.Tracer
// FxHook closes Tracer on stop; nil when no closable tracer was created.
FxHook TracerClosingHook
}
func provideTracer(ctx *bootstrap.ApplicationContext, props tracing.TracingProperties) (ret tracerOut) {
ret = tracerOut{
Tracer: opentracing.NoopTracer{},
}
if !props.Enabled {
return
}
tracers := make([]opentracing.Tracer, 0, 2)
if props.Jaeger.Enabled {
tracer, closer := jaegertracing.NewTracer(ctx, &props.Jaeger, &props.Sampler)
tracers = append(tracers, tracer)
ret.FxHook = &fx.Hook{
OnStop: func(ctx context.Context) error {
logger.WithContext(ctx).Infof("closing Jaeger Tracer...")
e := closer.Close()
if e != nil {
logger.WithContext(ctx).Errorf("failed to close Jaeger Tracer: %v", e)
}
logger.WithContext(ctx).Infof("Jaeger Tracer closed")
return e
},
}
}
if props.Zipkin.Enabled {
panic("zipkin is currently unsupported")
}
switch len(tracers) {
case 0:
return
case 1:
ret.Tracer = tracers[0]
return
default:
panic("multiple opentracing.Tracer detected. we currely only support single tracer")
}
}
/**************************
Setup
***************************/
// regDI lists the dependencies of initialize.
type regDI struct {
fx.In
// AppContext may carry the default tracer closer stored by EnableBootstrapTracing.
AppContext *bootstrap.ApplicationContext
Tracer opentracing.Tracer `optional:"true"`
FxHook TracerClosingHook `optional:"true"`
// we could include security configurations, customizations here
}
// initialize registers graceful-shutdown hooks for the tracers.
//
// When a TracerClosingHook is present, it is appended to the lifecycle. After it comes
// the closer of the default bootstrap tracer: the one stored in the application context
// by EnableBootstrapTracing if present, otherwise the package-level defaultTracerCloser.
// NOTE(review): the default bootstrap tracer's closer is only registered when FxHook is
// non-nil, so it is never closed when no closable tracer was provided — confirm intended.
func initialize(lc fx.Lifecycle, di regDI) {
	if di.Tracer == nil || di.FxHook == nil {
		return
	}
	lc.Append(*di.FxHook)
	defaultCloser := defaultTracerCloser
	if fromCtx, ok := di.AppContext.Value(kCtxDefaultTracerCloser{}).(fx.Hook); ok {
		defaultCloser = fromCtx
	}
	lc.Append(defaultCloser)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package instrument
import (
"context"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"go.uber.org/fx"
)
// opNameCli is the operation name of the span wrapping a CLI runner.
const opNameCli = "cli"
// cliRunnerTracingHooks is the bootstrap.CliRunnerLifecycleHooks implementation that traces CLI runners.
type cliRunnerTracingHooks struct {
tracer opentracing.Tracer
}
// CliRunnerTracingProvider returns an fx annotation that contributes the CLI runner
// tracing hooks to the bootstrap.FxCliRunnerGroup value group.
func CliRunnerTracingProvider() fx.Annotated {
	provider := fx.Annotated{Target: newCliRunnerTracingHooks}
	provider.Group = bootstrap.FxCliRunnerGroup
	return provider
}
// newCliRunnerTracingHooks creates CLI runner lifecycle hooks that trace with tracer.
func newCliRunnerTracingHooks(tracer opentracing.Tracer) bootstrap.CliRunnerLifecycleHooks {
	hooks := &cliRunnerTracingHooks{}
	hooks.tracer = tracer
	return hooks
}
// Before starts a fresh "cli" span (RPC-server kind) for the runner, ignoring any
// span already present in ctx, and returns the context carrying it.
func (h cliRunnerTracingHooks) Before(ctx context.Context, runner bootstrap.CliRunner) context.Context {
	op := tracing.WithTracer(h.tracer).
		WithOpName(opNameCli).
		WithOptions(tracing.SpanKind(ext.SpanKindRPCServerEnum))
	return op.ForceNewSpan(ctx)
}
// After finishes the runner's span, tagging it with "err" when the runner failed,
// and returns the context that preceded the span where possible.
func (h cliRunnerTracingHooks) After(ctx context.Context, runner bootstrap.CliRunner, err error) context.Context {
	op := tracing.WithTracer(h.tracer)
	if err == nil {
		return op.FinishAndRewind(ctx)
	}
	return op.WithOptions(tracing.SpanTag("err", err)).FinishAndRewind(ctx)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package instrument
import (
"context"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
// Operation names of the spans created for the bootstrap, startup and shutdown phases.
const (
opNameBootstrap = "bootstrap"
opNameStart = "startup"
opNameStop = "shutdown"
)
// EnableBootstrapTracing registers context options on bootstrapper that open
// "bootstrap", "startup" and "shutdown" spans with tracer.
func EnableBootstrapTracing(bootstrapper *bootstrap.Bootstrapper, tracer opentracing.Tracer) {
	bootstrapOpt := MakeBootstrapTracingOption(tracer, opNameBootstrap)
	startOpt := MakeStartTracingOption(tracer, opNameStart)
	stopOpt := MakeStopTracingOption(tracer, opNameStop)
	bootstrapper.AddInitialAppContextOptions(bootstrapOpt)
	bootstrapper.AddStartContextOptions(startOpt)
	bootstrapper.AddStopContextOptions(stopOpt)
}
// MakeBootstrapTracingOption returns a context option that starts an RPC-server
// span named opName: a child of the span in the context, or a new root span if there is none.
func MakeBootstrapTracingOption(tracer opentracing.Tracer, opName string) bootstrap.ContextOption {
	return func(ctx context.Context) context.Context {
		op := tracing.WithTracer(tracer).WithOpName(opName)
		op = op.WithOptions(tracing.SpanKind(ext.SpanKindRPCServerEnum))
		return op.NewSpanOrDescendant(ctx)
	}
}

// MakeStartTracingOption returns a context option that starts an RPC-server span
// named opName: a child of the span in the context, or a new root span if there is none.
func MakeStartTracingOption(tracer opentracing.Tracer, opName string) bootstrap.ContextOption {
	return func(ctx context.Context) context.Context {
		op := tracing.WithTracer(tracer).WithOpName(opName)
		op = op.WithOptions(tracing.SpanKind(ext.SpanKindRPCServerEnum))
		return op.NewSpanOrDescendant(ctx)
	}
}

// MakeStopTracingOption returns a context option that starts an RPC-server span
// named opName: a child of the span in the context, or a new root span if there is none.
// It does not finish the span that is current in the context.
func MakeStopTracingOption(tracer opentracing.Tracer, opName string) bootstrap.ContextOption {
	return func(ctx context.Context) context.Context {
		op := tracing.WithTracer(tracer).WithOpName(opName)
		op = op.WithOptions(tracing.SpanKind(ext.SpanKindRPCServerEnum))
		return op.NewSpanOrDescendant(ctx)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jaegertracing
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/uber/jaeger-client-go"
)
// init installs the Jaeger-aware ID extractors as tracing.DefaultLogValuers.
func init() {
	valuers := tracing.LogValuers{
		TraceIDValuer:  traceIdContextValuer,
		SpanIDValuer:   spanIdContextValuer,
		ParentIDValuer: parentIdContextValuer,
	}
	tracing.DefaultLogValuers = valuers
}
func traceIdContextValuer(ctx context.Context) (ret interface{}) {
span := tracing.SpanFromContext(ctx)
if span == nil {
return
}
switch span.Context().(type) {
case jaeger.SpanContext:
ret = jaegerTraceIdString(span.Context().(jaeger.SpanContext).TraceID())
default:
return
}
return
}
func spanIdContextValuer(ctx context.Context) (ret interface{}) {
span := tracing.SpanFromContext(ctx)
if span == nil {
return
}
switch span.Context().(type) {
case jaeger.SpanContext:
ret = jaegerSpanIdString(span.Context().(jaeger.SpanContext).SpanID())
}
return
}
func parentIdContextValuer(ctx context.Context) (ret interface{}) {
span := tracing.SpanFromContext(ctx)
if span == nil {
return
}
switch span.Context().(type) {
case jaeger.SpanContext:
ret = jaegerSpanIdString(span.Context().(jaeger.SpanContext).ParentID())
default:
return
}
return
}
// jaegerTraceIdString renders id as zero-padded lower-case hex: 16 digits for a
// 64-bit ID (High == 0), 32 digits otherwise. An invalid ID renders as "".
func jaegerTraceIdString(id jaeger.TraceID) string {
	switch {
	case !id.IsValid():
		return ""
	case id.High == 0:
		return fmt.Sprintf("%.16x", id.Low)
	default:
		return fmt.Sprintf("%.16x%016x", id.High, id.Low)
	}
}
// jaegerSpanIdString renders id as 16 zero-padded lower-case hex digits; zero renders as "".
func jaegerSpanIdString(id jaeger.SpanID) string {
	if id == 0 {
		return ""
	}
	return fmt.Sprintf("%.16x", uint64(id))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jaegertracing
import (
"fmt"
"github.com/opentracing/opentracing-go"
"github.com/uber/jaeger-client-go"
"github.com/uber/jaeger-client-go/zipkin"
"strconv"
"strings"
)
// Option is a function that configures a Propagator; pass it to NewZipkinB3Propagator.
type Option func(propagator *Propagator)
// BaggagePrefix returns an Option that sets the baggage key prefix, both for the
// single-header mode and for the wrapped zipkin.Propagator.
//goland:noinspection GoUnusedExportedFunction
func BaggagePrefix(prefix string) Option {
	return func(propagator *Propagator) {
		setDelegatePrefix := zipkin.BaggagePrefix(prefix)
		propagator.baggagePrefix = prefix
		setDelegatePrefix(&propagator.delegate)
	}
}
// SingleHeader returns an Option that switches Inject and Extract to the B3
// single-header ("b3") format instead of delegating to the multi-header format.
func SingleHeader() Option {
	return func(p *Propagator) { p.singleHeader = true }
}
// Propagator is an extension of zipkin.Propagator that supports Single Header propagation:
// See https://github.com/openzipkin/b3-propagation#single-header
type Propagator struct {
// delegate handles the multi-header B3 format when singleHeader is false.
delegate zipkin.Propagator
// baggagePrefix is prepended to baggage keys in single-header mode.
baggagePrefix string
singleHeader bool
}
// NewZipkinB3Propagator creates a Propagator for extracting and injecting Zipkin B3 headers into SpanContexts.
// Baggage is by default enabled and uses prefix 'baggage-', in both the multi-header (delegated)
// and the single-header format. Use BaggagePrefix to change it.
func NewZipkinB3Propagator(opts ...Option) *Propagator {
	p := Propagator{
		delegate: zipkin.NewZipkinB3HTTPHeaderPropagator(),
		// Match the delegate's default prefix. With an empty prefix, single-header
		// Extract would treat every carrier entry (including "b3" itself) as baggage,
		// and Inject would write baggage items under their bare keys.
		baggagePrefix: "baggage-",
	}
	for _, fn := range opts {
		fn(&p)
	}
	return &p
}
// Inject conforms to the Injector interface for encoding Zipkin B3 headers.
//
// Without SingleHeader, injection is delegated to the wrapped zipkin.Propagator.
// With SingleHeader, the carrier must be an opentracing.TextMapWriter (otherwise
// opentracing.ErrInvalidCarrier is returned). The span context is then written as
// one "b3" entry, and each baggage item is written under baggagePrefix+key.
func (p Propagator) Inject(sc jaeger.SpanContext, abstractCarrier interface{}) error {
	if !p.singleHeader {
		return p.delegate.Inject(sc, abstractCarrier)
	}
	textMapWriter, ok := abstractCarrier.(opentracing.TextMapWriter)
	if !ok {
		return opentracing.ErrInvalidCarrier
	}
	// https://github.com/openzipkin/b3-propagation#single-header
	// b3={TraceId}-{SpanId}-{SamplingState}-{ParentSpanId}, where the last two fields are optional.
	samplingState := "0"
	if sc.IsSampled() {
		samplingState = "1"
	}
	values := []string{sc.TraceID().String(), sc.SpanID().String(), samplingState}
	if sc.ParentID() != 0 {
		// ParentSpanId must use the same 16 lower-hex-digit encoding as SpanId.
		// SpanID.String() zero-pads; strconv.FormatUint would drop leading zeros.
		values = append(values, sc.ParentID().String())
	}
	textMapWriter.Set("b3", strings.Join(values, "-"))
	sc.ForeachBaggageItem(func(k, v string) bool {
		textMapWriter.Set(p.baggagePrefix+k, v)
		return true
	})
	return nil
}
// Extract conforms to the Extractor interface for decoding Zipkin B3 headers.
//
// Without SingleHeader, extraction is delegated to the wrapped zipkin.Propagator.
// With SingleHeader, the carrier must be an opentracing.TextMapReader (otherwise
// opentracing.ErrInvalidCarrier is returned). The span context is read from the
// "b3" entry, and every entry whose lower-cased key starts with baggagePrefix
// becomes a baggage item. A malformed "b3" value yields its parse error.
// opentracing.ErrSpanContextNotFound is returned when no valid trace ID was found.
func (p Propagator) Extract(abstractCarrier interface{}) (jaeger.SpanContext, error) {
	if !p.singleHeader {
		return p.delegate.Extract(abstractCarrier)
	}
	textMapReader, ok := abstractCarrier.(opentracing.TextMapReader)
	if !ok {
		return jaeger.SpanContext{}, opentracing.ErrInvalidCarrier
	}
	var traceID jaeger.TraceID
	var spanID jaeger.SpanID
	var parentID uint64
	sampled := false
	var baggage map[string]string
	err := textMapReader.ForeachKey(func(rawKey, value string) error {
		key := strings.ToLower(rawKey) // TODO not necessary for plain TextMap
		if strings.HasPrefix(key, p.baggagePrefix) {
			if baggage == nil {
				baggage = make(map[string]string)
			}
			baggage[key[len(p.baggagePrefix):]] = value
		}
		if key != "b3" {
			return nil
		}
		// https://github.com/openzipkin/b3-propagation#single-header
		// b3={TraceId}-{SpanId}-{SamplingState}-{ParentSpanId}, where the last two fields are optional.
		splits := strings.SplitN(value, "-", 4)
		if len(splits) < 2 {
			return fmt.Errorf("invalid b3 value")
		}
		var e error
		if traceID, e = jaeger.TraceIDFromString(splits[0]); e != nil {
			return e
		}
		if spanID, e = jaeger.SpanIDFromString(splits[1]); e != nil {
			return e
		}
		if len(splits) >= 3 {
			// SamplingState is "1" (accept), "0" (deny) or "d" (debug). Debug implies
			// accept, but strconv.ParseBool rejects "d", so it is handled explicitly.
			if state := splits[2]; state == "d" {
				sampled = true
			} else if sampled, e = strconv.ParseBool(state); e != nil {
				return e
			}
		}
		if len(splits) >= 4 {
			if parentID, e = strconv.ParseUint(splits[3], 16, 64); e != nil {
				return e
			}
		}
		return nil
	})
	switch {
	case err != nil:
		return jaeger.SpanContext{}, err
	case !traceID.IsValid():
		return jaeger.SpanContext{}, opentracing.ErrSpanContextNotFound
	default:
		return jaeger.NewSpanContext(traceID, spanID, jaeger.SpanID(parentID), sampled, baggage), nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package jaegertracing
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/opentracing/opentracing-go"
"github.com/uber/jaeger-client-go"
"io"
)
// logger is the "Tracing" logger of this package, used e.g. to report the chosen sampler.
var logger = log.New("Tracing")
// NewDefaultTracer creates a tracer for service "lanai" that never samples and
// discards all spans (constant-false sampler, null reporter). The returned
// io.Closer closes the tracer.
func NewDefaultTracer() (opentracing.Tracer, io.Closer) {
	sampler := jaeger.NewConstSampler(false)
	reporter := jaeger.NewNullReporter()
	return newTracer("lanai", sampler, reporter)
}
// NewTracer creates a Jaeger tracer named after the application, with the sampler
// and reporter configured by sp and jp. The returned io.Closer closes the tracer.
func NewTracer(ctx *bootstrap.ApplicationContext, jp *tracing.JaegerProperties, sp *tracing.SamplerProperties) (opentracing.Tracer, io.Closer) {
	serviceName := ctx.Name()
	return newTracer(serviceName, newSampler(ctx, sp), newReporter(ctx, jp, sp))
}
// newTracer creates a Jaeger tracer that propagates Zipkin B3 context, which is
// compatible with Spring-Sleuth powered services. HTTP headers use the multi-header
// B3 format; plain text maps use the single "b3" header format.
// See https://github.com/openzipkin/b3-propagation#single-header
// See https://github.com/jaegertracing/jaeger-client-go/blob/master/zipkin/README.md#NewZipkinB3HTTPHeaderPropagator
func newTracer(serviceName string, sampler jaeger.Sampler, reporter jaeger.Reporter) (opentracing.Tracer, io.Closer) {
	multiHeader := NewZipkinB3Propagator()
	singleHeader := NewZipkinB3Propagator(SingleHeader())
	opts := []jaeger.TracerOption{
		jaeger.TracerOptions.Injector(opentracing.HTTPHeaders, multiHeader),
		jaeger.TracerOptions.Extractor(opentracing.HTTPHeaders, multiHeader),
		jaeger.TracerOptions.Injector(opentracing.TextMap, singleHeader),
		jaeger.TracerOptions.Extractor(opentracing.TextMap, singleHeader),
		// Zipkin shares span ID between client and server spans; it must be enabled via the following option.
		jaeger.TracerOptions.ZipkinSharedRPCSpan(true),
	}
	return jaeger.NewTracer(serviceName, sampler, reporter, opts...)
}
// newSampler picks a sampler from sp, trying in order:
//   - never sample, when sampling is disabled;
//   - guaranteed-throughput probabilistic, when LowestRate > 0 and Probability is in (0, 1];
//   - probabilistic, when Probability is in (0, 1];
//   - rate limiting, when RateLimit > 0;
//   - otherwise never sample (with a warning).
// A candidate whose constructor fails falls through to the next one.
func newSampler(ctx context.Context, sp *tracing.SamplerProperties) jaeger.Sampler {
	if !sp.Enabled {
		return jaeger.NewConstSampler(false)
	}
	validProbability := sp.Probability > 0 && sp.Probability <= 1.0
	if validProbability && sp.LowestRate > 0 {
		if sampler, err := jaeger.NewGuaranteedThroughputProbabilisticSampler(sp.LowestRate, sp.Probability); err == nil {
			logger.WithContext(ctx).
				Infof("Use GuaranteedThroughputProbabilisticSampler with lowest rate %.3f/s and probability %%%2.1f",
					sp.LowestRate, sp.Probability*100)
			return sampler
		}
	}
	if validProbability {
		if sampler, err := jaeger.NewProbabilisticSampler(sp.Probability); err == nil {
			logger.WithContext(ctx).
				Infof("Use ProbabilisticSampler with lprobability %%%2.1f", sp.Probability*100)
			return sampler
		}
	}
	if sp.RateLimit > 0 {
		limiter := jaeger.NewRateLimitingSampler(sp.RateLimit)
		logger.WithContext(ctx).Infof("Use RateLimitingSampler with rate limit %.3f/s", sp.RateLimit)
		return limiter
	}
	logger.WithContext(ctx).Warnf("both rate limit and probability are not valid, tracing sampling is disabled")
	return jaeger.NewConstSampler(false)
}
// newReporter returns a remote reporter that sends spans over UDP to jp.Host:jp.Port.
// It returns a null reporter when sampling is disabled or the Jaeger host/port is not configured.
// It panics if the UDP transport cannot be created; the underlying error is included in the message.
// NOTE(review): jp.Enabled is not consulted here; presumably the caller checks it — confirm.
func newReporter(ctx context.Context, jp *tracing.JaegerProperties, sp *tracing.SamplerProperties) jaeger.Reporter {
	if !sp.Enabled || jp.Host == "" || jp.Port == 0 {
		return jaeger.NewNullReporter()
	}
	hostPort := fmt.Sprintf("%s:%d", jp.Host, jp.Port)
	// maxPacketSize 0 lets the library pick its default packet size
	transport, e := jaeger.NewUDPTransport(hostPort, 0)
	if e != nil {
		panic(fmt.Sprintf("unable to establish connection to Jaeger server at %s: %v", hostPort, e))
	}
	return jaeger.NewRemoteReporter(transport)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package tracing
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/pkg/errors"
)
const (
	// TracingPropertiesPrefix is the configuration prefix TracingProperties are bound from.
	TracingPropertiesPrefix = "tracing"
)

// TracingProperties is the root of the tracing configuration.
type TracingProperties struct {
	Enabled bool              `json:"enabled"`
	Jaeger  JaegerProperties  `json:"jaeger"`
	Zipkin  ZipkinProperties  `json:"zipkin"`
	Sampler SamplerProperties `json:"sampler"`
}

// JaegerProperties configures the Jaeger agent that spans are reported to.
type JaegerProperties struct {
	Enabled bool   `json:"enabled"`
	Host    string `json:"host"`
	Port    int    `json:"port"`
}

// ZipkinProperties configures Zipkin compatibility.
type ZipkinProperties struct {
	Enabled bool `json:"enabled"`
}

// SamplerProperties configures trace sampling.
type SamplerProperties struct {
	Enabled bool `json:"enabled"`
	// RateLimit is the number of traces sampled per second (config key "limit-per-second").
	RateLimit float64 `json:"limit-per-second"`
	// Probability is the sampling probability; values outside (0, 1] are treated as unset.
	Probability float64 `json:"probability"`
	// LowestRate is the guaranteed minimum sampling rate per second (config key "lowest-per-second").
	LowestRate float64 `json:"lowest-per-second"`
}

// NewTracingProperties creates a TracingProperties with default values:
// tracing and Jaeger enabled, Zipkin and sampling disabled, and a sampler rate limit of 10 per second.
func NewTracingProperties() *TracingProperties {
	return &TracingProperties{
		Enabled: true,
		Jaeger: JaegerProperties{
			Enabled: true,
		},
		Zipkin: ZipkinProperties{},
		Sampler: SamplerProperties{
			Enabled:   false,
			RateLimit: 10.0,
		},
	}
}
// BindTracingProperties creates a TracingProperties pre-populated by NewTracingProperties,
// binds the application configuration found under TracingPropertiesPrefix onto it,
// and returns the result by value. It panics if binding fails.
func BindTracingProperties(ctx *bootstrap.ApplicationContext) TracingProperties {
	props := NewTracingProperties()
	if err := ctx.Config().Bind(props, TracingPropertiesPrefix); err != nil {
		panic(errors.Wrap(err, "failed to bind TracingProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package cacheutils
import (
"context"
"fmt"
"reflect"
"sync"
"sync/atomic"
"time"
)
// Key identifies a cache entry.
type Key interface {
	// Hash returns the value used as the internal map key. It must be unique per logical key
	// and should not be a pointer (a pointer result is dereferenced before use).
	Hash() interface{}
}

// StringKey is a Key backed by a plain string; it is its own hash.
type StringKey string

// Hash returns the key itself.
func (s StringKey) Hash() interface{} { return s }

// String returns the underlying string.
func (s StringKey) String() string { return string(s) }

// MemCache is an in-memory cache whose entries are produced on demand by a LoadFunc.
type MemCache interface {
	// GetOrLoad returns the cached value for k. When the entry is missing, expired or rejected
	// by validator, it is (re)loaded with loader. A load error is cached like a value and
	// returned from this method. This is the main entry point of the cache.
	GetOrLoad(ctx context.Context, k Key, loader LoadFunc, validator ValidateFunc) (interface{}, error)
	// Update forcibly replaces the value of an existing entry using updater.
	// It does nothing and returns false when k is absent or invalidated; otherwise it returns true.
	// An error from updater is cached and returned. When the entry is still loading,
	// Update waits for the load to finish before applying the update.
	Update(ctx context.Context, k Key, updater UpdateFunc) (bool, error)
	// Delete removes the entry of k, valid or not.
	Delete(k Key)
	// Reset removes every entry, valid or not.
	Reset()
	// Evict drops invalid entries to free memory. The same cleanup also runs periodically,
	// so calling it manually is rarely needed.
	Evict()
}

// LoadFunc produces the value of k, its expiry time (zero means no expiry) and an optional error.
type LoadFunc func(ctx context.Context, k Key) (v interface{}, exp time.Time, err error)

// UpdateFunc produces a replacement for the value old of k, with its expiry and an optional error.
type UpdateFunc func(ctx context.Context, k Key, old interface{}) (v interface{}, exp time.Time, err error)

// ValidateFunc reports whether a cached value may still be served.
type ValidateFunc func(ctx context.Context, v interface{}) bool

// CacheOptions mutates a CacheOption; pass any number of them to NewMemCache.
type CacheOptions func(opt *CacheOption)

// CacheOption holds the tunables of the in-memory cache.
type CacheOption struct {
	// Heartbeat is the interval of the background eviction.
	Heartbeat time.Duration
	// LoadRetry is how many extra attempts GetOrLoad makes after finding an unusable entry.
	LoadRetry int
}
// cEntry is one cache slot.
// Until wg is released the loader goroutine is still writing value, expire and lastErr;
// once wg.Wait() returns those fields must be treated as read-only and isLoaded() reports true.
type cEntry struct {
	wg      sync.WaitGroup
	value   interface{}
	expire  time.Time
	lastErr error
	// invalid is a one-way 0 -> 1 flag, only touched atomically. An invalid entry is treated
	// exactly like a missing one and never becomes valid again. A zero flag does NOT prove the
	// entry is usable: callers must still inspect the other fields after wg.Wait().
	invalid uint64
	// loaded is a one-way 0 -> 1 flag, only touched atomically. It lets the evictor decide
	// whether expire can be read without blocking on wg; everyone else should use wg.Wait().
	loaded uint64
}

// isExpired reports whether the entry has a non-zero expiry that has been reached.
// It reads expire without synchronization and is NOT goroutine-safe.
func (ce *cEntry) isExpired() bool {
	if ce.expire.IsZero() {
		// zero expiry means the entry never expires
		return false
	}
	return !time.Now().Before(ce.expire)
}

// isInvalidated atomically reads the invalid flag.
func (ce *cEntry) isInvalidated() bool {
	return atomic.LoadUint64(&ce.invalid) > 0
}

// invalidate atomically sets the invalid flag.
func (ce *cEntry) invalidate() {
	atomic.StoreUint64(&ce.invalid, 1)
}

// isLoaded atomically reads the loaded flag.
func (ce *cEntry) isLoaded() bool {
	return atomic.LoadUint64(&ce.loaded) > 0
}

// markLoaded atomically sets the loaded flag.
func (ce *cEntry) markLoaded() {
	atomic.StoreUint64(&ce.loaded, 1)
}
// newEntryFunc builds a fresh entry for key k.
type newEntryFunc func(ctx context.Context, k Key) *cEntry

// replaceEntryFunc builds the entry that replaces old for key k.
type replaceEntryFunc func(ctx context.Context, k Key, old *cEntry) *cEntry

// cache is the in-memory implementation returned by NewMemCache.
// mtx guards store; reaper drives the periodic eviction.
type cache struct {
	CacheOption
	mtx    sync.RWMutex
	store  map[interface{}]*cEntry
	reaper *time.Ticker
}
// NewMemCache creates an in-memory cache and starts its background reaper.
// Defaults: Heartbeat 10 minutes, LoadRetry 2; opts are applied on top of them in order.
func NewMemCache(opts ...CacheOptions) *cache {
	var opt CacheOption
	opt.Heartbeat = 10 * time.Minute
	opt.LoadRetry = 2
	for i := range opts {
		opts[i](&opt)
	}
	c := &cache{CacheOption: opt, store: make(map[interface{}]*cEntry)}
	c.startReaper()
	return c
}
// GetOrLoad returns the value cached under k, loading it with loader when no usable entry exists.
// An existing entry is usable when it has not expired and either carries a cached load error,
// or validator is nil, or validator accepts its value. A freshly created entry is returned
// without validation. Unusable entries are invalidated and the lookup is retried,
// up to c.LoadRetry extra times (a negative LoadRetry is treated as 0).
// Errors returned by loader are cached with the entry and returned from here.
func (c *cache) GetOrLoad(ctx context.Context, k Key, loader LoadFunc, validator ValidateFunc) (interface{}, error) {
	if loader == nil {
		return nil, fmt.Errorf("unable to load valid entry: LoadFunc is nil")
	}
	// Always make at least one attempt: a negative LoadRetry used to skip the loop
	// entirely and fail without ever calling loader.
	retries := c.LoadRetry
	if retries < 0 {
		retries = 0
	}
	// The creator closure does not depend on the iteration, so build it once.
	create := c.newEntryFunc(loader)
	for i := 0; i <= retries; i++ {
		// getOrNew guarantees that at most one goroutine creates the new entry (cache-wide RW lock)
		entry, isNew := c.getOrNew(ctx, k, create)
		if entry == nil {
			return nil, fmt.Errorf("[Internal Error] cache returns nil entry")
		}
		// wait for the loader; afterwards the entry's content is immutable
		entry.wg.Wait()
		// freshly created entries skip expiry/validation checks
		if isNew || !entry.isExpired() && (entry.lastErr != nil || validator == nil || validator(ctx, entry.value)) {
			if entry.lastErr != nil {
				return nil, entry.lastErr
			}
			return entry.value, nil
		}
		// expired or rejected: invalidate so the next attempt creates a new entry
		entry.invalidate()
	}
	return nil, fmt.Errorf("unable to load valid entry")
}
// Update replaces the value of the live entry of k using updater, after any in-flight load finishes.
// It returns (false, nil) when k is absent or invalidated. Otherwise it returns true together with
// the updater's error, unless the updated entry is already expired, in which case the error is nil.
func (c *cache) Update(ctx context.Context, k Key, updater UpdateFunc) (bool, error) {
	if updater == nil {
		return false, fmt.Errorf("unable to update: UpdateFunc is nil")
	}
	// cheap existence check under the read lock before taking the write lock
	if _, exists := c.get(k); !exists {
		return false, nil
	}
	updated := c.replaceIfPresent(ctx, k, c.updateEntryFunc(updater))
	if updated == nil {
		return false, nil
	}
	updated.wg.Wait()
	if updated.isExpired() {
		return true, nil
	}
	return true, updated.lastErr
}
// Delete removes the entry of k, whether it is valid or not.
func (c *cache) Delete(k Key) {
	c.set(k, nil)
}

// Reset removes every entry by swapping in an empty store.
func (c *cache) Reset() {
	// Replacing c.store is a write: it needs the exclusive lock. The previous read lock let
	// Reset race with concurrent readers (get, evict) and writers.
	c.mtx.Lock()
	defer c.mtx.Unlock()
	c.store = map[interface{}]*cEntry{}
}

// Evict immediately removes expired and invalidated entries; the reaper does the same periodically.
func (c *cache) Evict() {
	c.evict()
}
// newEntryFunc returns a creator that allocates an entry and starts loader for it
// in a new goroutine; the entry's WaitGroup is released once loading completes.
// The returned creator itself does no locking.
func (c *cache) newEntryFunc(loader LoadFunc) newEntryFunc {
	create := func(ctx context.Context, key Key) *cEntry {
		entry := new(cEntry)
		// released by c.load when the loader returns
		entry.wg.Add(1)
		go c.load(ctx, key, entry, loader)
		return entry
	}
	return create
}
// updateEntryFunc returns a replacer that builds a new entry from old by running updater
// (in a new goroutine) on old's value. The replacer does no locking and expects old to be
// non-nil and fully loaded.
func (c *cache) updateEntryFunc(updater UpdateFunc) replaceEntryFunc {
	return func(ctx context.Context, key Key, old *cEntry) *cEntry {
		// seed with the old state; c.load overwrites all three fields when the update finishes
		entry := &cEntry{value: old.value, expire: old.expire, lastErr: old.lastErr}
		entry.wg.Add(1)
		asLoader := func(ctx context.Context, k Key) (interface{}, time.Time, error) {
			return updater(ctx, k, old.value)
		}
		go c.load(ctx, key, entry, asLoader)
		return entry
	}
}
// load runs loader, stores its results into entry, flags it as loaded and releases its WaitGroup.
// It must be called exactly once per entry and does no locking.
func (c *cache) load(ctx context.Context, key Key, entry *cEntry, loader LoadFunc) {
	entry.value, entry.expire, entry.lastErr = loader(ctx, key)
	// publish the loaded flag before waking up waiters
	entry.markLoaded()
	entry.wg.Done()
}
// getOrNew returns the live entry of pKey, or creates and stores one using newIfAbsent.
// isNew reports whether the returned entry was created by this call. Goroutine-safe.
func (c *cache) getOrNew(ctx context.Context, pKey Key, newIfAbsent newEntryFunc) (entry *cEntry, isNew bool) {
	// fast path under the read lock
	if existing, found := c.get(pKey); found {
		return existing, false
	}
	// slow path: re-checked under the write lock
	return c.newIfAbsent(ctx, pKey, newIfAbsent)
}
// newIfAbsent returns the live entry of pKey if there is one; otherwise it builds a new entry
// with creator, stores it and reports isNew = true. Runs under the write lock (goroutine-safe).
func (c *cache) newIfAbsent(ctx context.Context, pKey Key, creator newEntryFunc) (entry *cEntry, isNew bool) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	existing, found := c.getValue(pKey)
	if found && !existing.isInvalidated() {
		return existing, false
	}
	entry = creator(ctx, pKey)
	c.setValue(pKey, entry)
	return entry, true
}
// replaceIfPresent replaces the live entry of pKey with the one built by replacer and returns it.
// It returns nil when the key is absent or its entry is invalidated.
// If the current entry is still loading, it waits for that load first, so replacer always
// receives a loaded entry. This method is goroutine-safe.
func (c *cache) replaceIfPresent(ctx context.Context, pKey Key, replacer replaceEntryFunc) (entry *cEntry) {
	for {
		existing, ok := c.get(pKey)
		if !ok {
			return nil
		}
		// Wait OUTSIDE the cache-wide lock. Holding the write lock here blocked every other key
		// until this load finished, and deadlocked if the loader itself used this cache.
		existing.wg.Wait()
		replaced := func() *cEntry {
			c.mtx.Lock()
			defer c.mtx.Unlock()
			current, found := c.getValue(pKey)
			// only swap if nobody replaced/invalidated the entry while we were waiting
			if !found || current != existing || current.isInvalidated() {
				return nil
			}
			next := replacer(ctx, pKey, existing)
			c.setValue(pKey, next)
			return next
		}()
		if replaced != nil {
			return replaced
		}
		// the entry changed concurrently; re-read and try again
	}
}
// set stores v under pKey (a nil v removes the key) while holding the write lock. Goroutine-safe.
func (c *cache) set(pKey Key, v *cEntry) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	c.setValue(pKey, v)
}

// get returns the entry of pKey unless it is missing or invalidated. Goroutine-safe (read lock).
func (c *cache) get(pKey Key) (*cEntry, bool) {
	c.mtx.RLock()
	defer c.mtx.RUnlock()
	v, ok := c.getValue(pKey)
	if !ok || v.isInvalidated() {
		return nil, false
	}
	return v, true
}
// getValue not goroutine-safe
func (c *cache) getValue(pKey Key) (*cEntry, bool) {
k := reflect.Indirect(reflect.ValueOf(pKey.Hash())).Interface()
if v, ok := c.store[k]; ok && v != nil {
return v, true
}
return nil, false
}
// setValue not goroutine-safe
func (c *cache) setValue(pKey Key, v *cEntry) {
k := reflect.Indirect(reflect.ValueOf(pKey.Hash())).Interface()
if v == nil {
delete(c.store, k)
} else {
c.store[k] = v
c.deleteInvalidatedValues()
}
}
// deleteInvalidatedValues remove given keys
// this method is not goroutine-safe
func (c *cache) deleteInvalidatedValues() {
for k, v := range c.store {
if v.isInvalidated() {
delete(c.store, k)
}
}
}
// startReaper starts a background goroutine that calls evict every c.Heartbeat.
// A non-positive Heartbeat disables periodic eviction (Evict can still be called manually)
// instead of panicking inside time.NewTicker.
// NOTE(review): the ticker and goroutine are never stopped, so each cache keeps one goroutine
// alive for the rest of the process.
func (c *cache) startReaper() {
	if c.Heartbeat <= 0 {
		return
	}
	c.reaper = time.NewTicker(c.Heartbeat)
	go func() {
		for range c.reaper.C {
			c.evict()
		}
	}()
}
// evict removes expired entries in two phases, so the exclusive lock is only held for deletion.
func (c *cache) evict() {
	// phase 1 (read lock): flag loaded-and-expired entries as invalid. Entries still loading
	// are skipped because their expire field may still be written by the loader.
	invalidateExpired := func() {
		c.mtx.RLock()
		defer c.mtx.RUnlock()
		for _, entry := range c.store {
			if entry.isInvalidated() || !entry.isLoaded() {
				continue
			}
			if entry.isExpired() {
				entry.invalidate()
			}
		}
	}
	invalidateExpired()

	// phase 2 (write lock): drop everything flagged invalid
	c.mtx.Lock()
	defer c.mtx.Unlock()
	c.deleteInvalidatedValues()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package utils
import (
"context"
)
// MutableContext is a context.Context backed by an internal key-value store whose entries
// can be changed after the context has been created.
// For any context.Context derived from a MutableContext, call FindMutableContext to obtain
// a MutableContextAccessor that can change or list those entries.
// See FindMutableContext and MutableContextAccessor for details.
type MutableContext interface {
	context.Context
	Set(key, value any)
}

// ListableContext complements MutableContext by exposing every value stored in the context.
type ListableContext interface {
	context.Context
	Values() map[interface{}]interface{}
}
// ContextValuer is an extra value source consulted by MutableContext's Value(key).
// When the key is not in the context's own store, every ContextValuer is tried in order
// before the lookup is handed to the parent context.
// See NewMutableContext and MakeMutableContext.
type ContextValuer func(key interface{}) interface{}

// mutableContextKey is the type of ckMutableContext, the key under which a mutableContext
// returns itself from Value.
type mutableContextKey struct{}

var ckMutableContext = mutableContextKey{}

// mutableContext implements MutableContext and ListableContext.
// NOTE(review): the previous comment also listed GinContext, which is declared outside this view.
type mutableContext struct {
	context.Context
	values  map[interface{}]interface{}
	valuers []ContextValuer
}

// Value looks key up in this order: the self-reference key, the local store, the valuers,
// and finally the parent context. nil results from the store or valuers count as "not found".
func (ctx *mutableContext) Value(key interface{}) (ret interface{}) {
	if key == ckMutableContext {
		return ctx
	}
	// local store first, so values set here override everything else
	if v, found := ctx.values[key]; found && v != nil {
		return v
	}
	for _, lookup := range ctx.valuers {
		if v := lookup(key); v != nil {
			return v
		}
	}
	return ctx.Context.Value(key)
}

// Set stores value under key; a nil value removes the key, and a nil key is ignored.
func (ctx *mutableContext) Set(key any, value any) {
	switch {
	case key == nil:
		// nothing to do
	case value == nil:
		delete(ctx.values, key)
	default:
		ctx.values[key] = value
	}
}

// Values merges the stores of this context and all its mutableContext ancestors.
// When a key appears at several levels, the value from the context closest to the receiver wins.
// Values coming from ContextValuers are not included.
func (ctx *mutableContext) Values() (values map[interface{}]interface{}) {
	var chain []*mutableContext
	for cur := ctx; cur != nil; {
		chain = append(chain, cur)
		next, _ := cur.Context.Value(ckMutableContext).(*mutableContext)
		cur = next
	}
	values = make(map[interface{}]interface{})
	// apply from the furthest ancestor inward so nearer contexts override
	for i := len(chain) - 1; i >= 0; i-- {
		for k, v := range chain[i].values {
			values[k] = v
		}
	}
	return values
}
// NewMutableContext wraps parent (context.Background() when nil) with an empty mutable store,
// optionally backed by extra ContextValuer lookups.
//nolint:contextcheck // false positive - Non-inherited new context, use function like `context.WithXXX` instead
func NewMutableContext(parent context.Context, valuers ...ContextValuer) MutableContext {
	base := parent
	if base == nil {
		base = context.Background()
	}
	mc := &mutableContext{Context: base, valuers: valuers}
	mc.values = make(map[interface{}]interface{})
	return mc
}
// MakeMutableContext returns parent unchanged when it already is a MutableContext created by
// this package and no ContextValuer is given; otherwise it wraps parent via NewMutableContext.
// Note: a context that merely has a MutableContext somewhere up its hierarchy is wrapped anew,
// and the new store is NOT shared with the ancestor's.
func MakeMutableContext(parent context.Context, valuers ...ContextValuer) MutableContext {
	existing, isMutable := parent.(*mutableContext)
	if !isMutable || len(valuers) != 0 {
		return NewMutableContext(parent, valuers...)
	}
	return existing
}
// MutableContextAccessor changes and lists the key-values of a MutableContext.
type MutableContextAccessor interface {
	Set(key, value any)
	Values() (values map[any]any)
}

// FindMutableContext locates the nearest MutableContext in ctx's hierarchy and returns it as a
// MutableContextAccessor, or nil when there is none.
//
// Important: the returned accessor may belong to a parent of ctx, so changes made through it
// are visible to that parent and to every context derived from it.
func FindMutableContext(ctx context.Context) MutableContextAccessor {
	mc, found := ctx.Value(ckMutableContext).(*mutableContext)
	if !found {
		return nil
	}
	return mc
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package cryptoutils
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"io"
"os"
"strings"
)
// LoadCert reads the PEM file at path file and parses every "CERTIFICATE" block in it.
// Blocks of other types are skipped. An error is returned if the file cannot be read or
// any certificate block fails to parse. A file with no certificate blocks yields (nil, nil).
func LoadCert(file string) ([]*x509.Certificate, error) {
	// os.ReadFile closes the file; the previous os.Open + io.ReadAll leaked the descriptor.
	certBytes, err := os.ReadFile(file)
	if err != nil {
		return nil, err
	}
	var result []*x509.Certificate
	for block, rest := pem.Decode(certBytes); block != nil; block, rest = pem.Decode(rest) {
		if block.Type != "CERTIFICATE" {
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, err
		}
		result = append(result, cert)
	}
	return result, nil
}
// LoadPrivateKey reads an RSA private key from the first PEM block of file.
// With a non-empty keyPassword the block is decrypted (legacy PEM encryption) and parsed as
// PKCS#1; otherwise it is parsed as PKCS#8 and must contain an RSA key.
// An error is returned when the file cannot be read, contains no PEM block, cannot be
// decrypted/parsed, or holds a non-RSA key.
func LoadPrivateKey(file string, keyPassword string) (*rsa.PrivateKey, error) {
	// os.ReadFile closes the file; the previous os.Open + io.ReadAll leaked the descriptor.
	keyBytes, err := os.ReadFile(file)
	if err != nil {
		return nil, err
	}
	keyBlock, _ := pem.Decode(keyBytes)
	if keyBlock == nil {
		// previously a nil-pointer panic on keyBlock
		return nil, errors.New("no PEM encoded key found in " + file)
	}
	if keyPassword != "" {
		//nolint:staticcheck // TODO find alternative
		decrypted, err := x509.DecryptPEMBlock(keyBlock, []byte(keyPassword))
		if err != nil {
			return nil, err
		}
		return x509.ParsePKCS1PrivateKey(decrypted)
	}
	key, err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
	if err != nil {
		return nil, err
	}
	rsaKey, ok := key.(*rsa.PrivateKey)
	if !ok {
		return nil, errors.New("private key is not rsa key")
	}
	return rsaKey, nil
}
// RandReader is the source of cryptographically secure random bytes used by RandomBytes.
// It defaults to crypto/rand.Reader and may be replaced in tests.
var RandReader = rand.Reader

// RandomBytes returns n bytes read from RandReader. It panics if RandReader cannot supply
// n bytes.
func RandomBytes(n int) []byte {
	buf := make([]byte, n)
	_, err := io.ReadFull(RandReader, buf)
	if err != nil {
		panic(err)
	}
	return buf
}
// LoadMultiBlockPem loads every block of the PEM file at path, in file order.
// Supported block types are parsed:
//   - "* PRIVATE KEY": PKCS#8, PKCS#1 (RSA) or SEC 1 (EC), decrypted first when password is non-empty
//   - "* PUBLIC KEY": PKIX, falling back to PKCS#1
//   - "CERTIFICATE": X.509
//
// Any other block is returned as the raw *pem.Block. An error is returned if the file cannot be
// read or any supported block fails to parse.
func LoadMultiBlockPem(path string, password string) ([]interface{}, error) {
	// os.ReadFile closes the file; the previous os.Open + io.ReadAll leaked the descriptor.
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	result := []interface{}{}
	for block, rest := pem.Decode(data); block != nil; block, rest = pem.Decode(rest) {
		var item interface{}
		var e error
		switch {
		case strings.HasSuffix(block.Type, "PRIVATE KEY"):
			item, e = parsePrivateKey(block, password)
		case strings.HasSuffix(block.Type, "PUBLIC KEY"):
			item, e = parsePublicKey(block)
		case block.Type == "CERTIFICATE":
			item, e = parseX509Cert(block)
		default:
			item = block
		}
		if e != nil {
			return nil, e
		}
		result = append(result, item)
	}
	return result, nil
}

// parsePrivateKey decrypts block when password is non-empty, then tries PKCS#8, PKCS#1 (RSA only)
// and finally SEC 1 (EC), returning the first successful parse or the SEC 1 error.
func parsePrivateKey(block *pem.Block, password string) (interface{}, error) {
	data := block.Bytes
	if password != "" {
		//nolint:staticcheck // TODO find alternative
		decrypted, e := x509.DecryptPEMBlock(block, []byte(password))
		if e != nil {
			return nil, e
		}
		data = decrypted
	}
	// try PKCS8 first
	if key, e := x509.ParsePKCS8PrivateKey(data); e == nil {
		return key, nil
	}
	// fallback to PKCS1, which only handles RSA keys
	if key, e := x509.ParsePKCS1PrivateKey(data); e == nil {
		return key, nil
	}
	// this handles EC keys
	return x509.ParseECPrivateKey(data)
}

// parsePublicKey tries PKIX and falls back to PKCS#1.
// (There is no PKCS#8 for public keys: PKCS#8 covers private keys only.)
func parsePublicKey(block *pem.Block) (interface{}, error) {
	if key, e := x509.ParsePKIXPublicKey(block.Bytes); e == nil {
		return key, nil
	}
	return x509.ParsePKCS1PublicKey(block.Bytes)
}

// parseX509Cert parses a DER-encoded X.509 certificate block.
func parseX509Cert(block *pem.Block) (*x509.Certificate, error) {
	return x509.ParseCertificate(block.Bytes)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package errorutils
import (
"encoding/gob"
"fmt"
"reflect"
)
// reserved maps each reserved category code to the package path recorded by Reserve.
var reserved = map[int64]string{}

// init registers this package's error implementations with encoding/gob, so they can be
// transmitted as values of interface type.
func init() {
	for _, v := range []interface{}{(*CodedError)(nil), (*nestedError)(nil)} {
		gob.Register(v)
	}
}
// Reserve validates an error category and reserves its code range. Packages that define an
// error hierarchy usually call it from init(). It panics unless all of the following hold:
//  1. err implements error, ErrorCoder and ComparableErrorCoder;
//  2. its code mask equals ReservedMask (i.e. it is a category error made by NewErrorCategory);
//  3. all bits of its code below ReservedOffset are 0 (nothing outside ReservedMask is set);
//  4. its code is not already reserved.
//
// On success the error's type is registered with encoding/gob and the code is recorded.
func Reserve(err interface{}) {
	_, isErr := err.(error)
	coder, isCoder := err.(ErrorCoder)
	masker, isMasker := err.(ComparableErrorCoder)
	// All three interfaces are required. The old type switch accepted any ONE of them and
	// then panicked on an unchecked type assertion with an opaque runtime error.
	if !isErr || !isCoder || !isMasker {
		panic(fmt.Errorf("cannot reserve error category %T", err))
	}
	if masker.CodeMask() != ReservedMask {
		panic(fmt.Errorf("cannot reserve error category with code mask %x", masker.CodeMask()))
	}
	code := coder.Code()
	if code&^ReservedMask != 0 {
		panic(fmt.Errorf("cannot reserve error category with code %x, it's not a category level codes", code))
	}
	if pkg, ok := reserved[code]; ok {
		// the old format string fed pkg into %x and never printed the code
		panic(fmt.Errorf("error category with code %x is already registered by %s", code, pkg))
	}
	// try reserve
	gob.Register(err)
	// PkgPath is empty for unnamed types such as pointers, so record the pointee's package
	t := reflect.TypeOf(err)
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	reserved[code] = t.PkgPath()
}
// ErrorCoder is implemented by errors that carry a numeric code.
type ErrorCoder interface {
	Code() int64
}

// ComparableErrorCoder is implemented by errors that provide the mask applied to another
// error's code when comparing against them.
type ComparableErrorCoder interface {
	CodeMask() int64
}

// NestedError is implemented by errors that wrap a cause.
type NestedError interface {
	// Cause returns the directly nested error.
	Cause() error
	// RootCause returns the innermost cause, as if Cause were followed repeatedly.
	RootCause() error
}

// ComparableError is implemented by errors with custom errors.Is matching.
type ComparableError interface {
	Is(target error) bool
}

// Unwrapper is implemented by errors that expose a wrapped error to errors.Unwrap.
type Unwrapper interface {
	Unwrap() error
}

// Hasher is for error implementations that are not naturally hashable, such as structs holding
// a slice, map or other non-comparable field; using them as a map key panics.
// Code that keys maps by error (e.g. map[error]interface{}) should key by Hasher.Hash() instead.
// Note: CodedError does not implement this interface because it is hashable as is.
type Hasher interface {
	Hash() error
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package errorutils
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
)
// CodedError is a concrete error carrying a message (ErrMsg), a numeric code (ErrCode),
// a comparison mask (ErrMask) and an optional cause (Nested).
// It implements error, ErrorCoder, ComparableErrorCoder, NestedError, ComparableError,
// encoding.TextMarshaler, encoding.BinaryMarshaler and encoding.BinaryUnmarshaler.
type CodedError struct {
	ErrMsg  string
	ErrCode int64
	ErrMask int64
	Nested  error
}

// Error returns the error message.
func (e CodedError) Error() string {
	return e.ErrMsg
}

// Code returns the error code.
func (e CodedError) Code() int64 {
	return e.ErrCode
}

// CodeMask returns the mask applied to another error's code when this error is the Is target.
func (e CodedError) CodeMask() int64 {
	return e.ErrMask
}

// Cause returns the directly nested error, which may be nil.
func (e CodedError) Cause() error {
	return e.Nested
}

// RootCause delegates to Nested's RootCause when Nested is a NestedError;
// otherwise it returns Nested itself (possibly nil).
func (e CodedError) RootCause() error {
	//nolint:errorlint
	next, chained := e.Nested.(NestedError)
	if !chained {
		return e.Nested
	}
	return next.RootCause()
}

// WithMessage creates a concrete error with e's code and a message formatted from msg and args.
func (e CodedError) WithMessage(msg string, args ...interface{}) *CodedError {
	formatted := fmt.Errorf(msg, args...)
	return NewCodedError(e.ErrCode, formatted)
}

// WithCause creates a concrete error with e's code, a message formatted from msg and args,
// and cause chained as its nested error.
func (e CodedError) WithCause(cause error, msg string, args ...interface{}) *CodedError {
	formatted := fmt.Errorf(msg, args...)
	return NewCodedError(e.ErrCode, formatted, cause)
}

// MarshalText implements encoding.TextMarshaler using the error message.
func (e CodedError) MarshalText() ([]byte, error) {
	text := e.Error()
	return []byte(text), nil
}

// MarshalBinary implements encoding.BinaryMarshaler.
// Layout: ErrCode and ErrMask as big-endian int64 values, then the message, then a NUL byte.
// Note: the cause is deliberately not serialized, to avoid cyclic references.
func (e CodedError) MarshalBinary() ([]byte, error) {
	var buf bytes.Buffer
	for _, n := range [...]int64{e.ErrCode, e.ErrMask} {
		if err := binary.Write(&buf, binary.BigEndian, n); err != nil {
			return nil, err
		}
	}
	if _, err := buf.WriteString(e.Error()); err != nil {
		return nil, err
	}
	// the NUL byte terminates the message for UnmarshalBinary
	if err := buf.WriteByte(0); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// UnmarshalBinary implements encoding.BinaryUnmarshaler for the layout written by MarshalBinary.
// e is only modified when the whole input decodes successfully.
func (e *CodedError) UnmarshalBinary(data []byte) error {
	buf := bytes.NewBuffer(data)
	var code, mask int64
	for _, dst := range []*int64{&code, &mask} {
		if err := binary.Read(buf, binary.BigEndian, dst); err != nil {
			return err
		}
	}
	msg, err := buf.ReadBytes(0)
	if err != nil {
		return err
	}
	// msg ends with the NUL delimiter; strip it
	e.ErrCode, e.ErrMask, e.ErrMsg = code, mask, string(msg[:len(msg)-1])
	return nil
}

// Is reports whether target is an ErrorCoder whose code equals e's code, where e's code is first
// masked with target's CodeMask when target is a ComparableErrorCoder. This lets a concrete error
// match its category/type/sub-type errors as well as errors with the exact same code.
//nolint:errorlint
func (e CodedError) Is(target error) bool {
	code := e.ErrCode
	if masker, ok := target.(ComparableErrorCoder); ok {
		code &= masker.CodeMask()
	}
	coder, ok := target.(ErrorCoder)
	return ok && code == coder.Code()
}
// nestedError pairs an error with the error that caused it; it implements error and NestedError.
type nestedError struct {
	error
	nested error
}

// Is reports whether target matches the error itself or, when present, its nested cause (via errors.Is).
func (e nestedError) Is(target error) bool {
	return errors.Is(e.error, target) || e.nested != nil && errors.Is(e.nested, target)
}

// Cause returns the directly nested error, which may be nil.
func (e nestedError) Cause() error {
	return e.nested
}

// RootCause returns the innermost error of the cause chain, following Cause() through every
// NestedError. Without a nested error, the error itself is returned.
//nolint:errorlint
func (e nestedError) RootCause() error {
	if e.nested == nil {
		return e.error
	}
	root := e.nested
	for {
		n, ok := root.(NestedError)
		if !ok {
			return root
		}
		next := n.Cause()
		if next == nil {
			// The deepest link is itself the root cause. The old loop fell through here and
			// wrongly returned the outermost e.error.
			return root
		}
		root = next
	}
}
/************************
	Constructors
*************************/

// newCodedError assembles a CodedError from its parts.
func newCodedError(code int64, msg string, mask int64, cause error) *CodedError {
	ce := &CodedError{ErrCode: code, ErrMask: mask}
	ce.ErrMsg = msg
	ce.Nested = cause
	return ce
}

// NewErrorCategory creates a category-level error: code is truncated with ReservedMask, which is
// also used as the comparison mask. e is formatted with %v as the message.
func NewErrorCategory(code int64, e interface{}) *CodedError {
	return newCodedError(code&ReservedMask, fmt.Sprintf("%v", e), ReservedMask, nil)
}

// NewErrorType creates a type-level error: code is truncated with ErrorTypeMask, which is also
// used as the comparison mask. e is formatted with %v as the message.
func NewErrorType(code int64, e interface{}) *CodedError {
	return newCodedError(code&ErrorTypeMask, fmt.Sprintf("%v", e), ErrorTypeMask, nil)
}

// NewErrorSubType creates a sub-type-level error: code is truncated with ErrorSubTypeMask, which
// is also used as the comparison mask. e is formatted with %v as the message.
func NewErrorSubType(code int64, e interface{}) *CodedError {
	return newCodedError(code&ErrorSubTypeMask, fmt.Sprintf("%v", e), ErrorSubTypeMask, nil)
}
// construct turns e into an error: errors are returned as is, fmt.Stringer values and strings
// become errors.New errors, and anything else is formatted with %v (nil yields "<nil>").
func construct(e interface{}) error {
	// bind the type-switch variable instead of re-asserting e in every case
	switch v := e.(type) {
	case error:
		return v
	case fmt.Stringer:
		return errors.New(v.String())
	case string:
		return errors.New(v)
	default:
		return fmt.Errorf("%v", v)
	}
}
// NewCodedError creates a concrete error with the given code and DefaultErrorCodeMask. It cannot be
// used as an ErrorType or ErrorSubType for comparison.
// e (formatted with %v) becomes the message. e followed by causes is chained, outermost first,
// into the nested cause. Each item may be a string, an error, a fmt.Stringer or any other value
// (formatted with %v). nil entries in causes are skipped.
func NewCodedError(code int64, e interface{}, causes ...interface{}) *CodedError {
	chain := make([]interface{}, 0, len(causes)+1)
	chain = append(chain, e)
	chain = append(chain, causes...)
	// build the nested chain from the innermost item outward
	var cause error
	for i := len(chain) - 1; i >= 0; i-- {
		// A nil cause (e.g. WithCause(nil, ...)) used to be chained as a bogus "<nil>" error.
		if i > 0 && chain[i] == nil {
			continue
		}
		current := construct(chain[i])
		if cause == nil {
			cause = current
		} else {
			cause = &nestedError{
				error:  current,
				nested: cause,
			}
		}
	}
	return newCodedError(code, fmt.Sprintf("%v", e), DefaultErrorCodeMask, cause)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package errorutils
import (
"errors"
"fmt"
)
// WrappedError is an embeddable struct that makes it easy to implement a concrete error within
// an error hierarchy without its own error code. It has three components:
//   - ErrIs is an anchor error used for comparison (Is)
//   - Type is the parent error indicating its type, a CodedError (Unwrap)
//   - ErrMsg is the error's actual string value (Error)
type WrappedError struct {
	ErrIs  error
	Type   *CodedError
	ErrMsg string
}

// Error returns ErrMsg.
func (e WrappedError) Error() string {
	return e.ErrMsg
}

// Is reports whether target is a WrappedError (or embeds one) with a matching ErrIs anchor and
// a matching Type, both compared with errors.Is. A nil Type only matches a nil Type.
func (e WrappedError) Is(target error) bool {
	//nolint:errorlint // type assert is intentional
	t, ok := target.(compareTargeter)
	if !ok {
		return false
	}
	other := t.target()
	// No struct-wide `==` shortcut: it panics when ErrIs holds a non-comparable value,
	// and errors.Is already short-circuits on identical comparable errors.
	if !errors.Is(e.ErrIs, other.ErrIs) {
		return false
	}
	// Passing a typed-nil *CodedError to errors.Is would call CodedError.Is on a nil pointer and panic.
	if e.Type == nil || other.Type == nil {
		return e.Type == other.Type
	}
	return errors.Is(e.Type, other.Type)
}

// Unwrap returns the type error, so errors.Is(e, errorType) is true whenever
// errors.Is(e.Type, errorType) is. A nil Type unwraps to a nil error rather than a
// non-nil interface holding a nil pointer.
func (e WrappedError) Unwrap() error {
	if e.Type == nil {
		return nil
	}
	return e.Type
}

// MarshalText implements encoding.TextMarshaler
func (e WrappedError) MarshalText() ([]byte, error) {
	return []byte(e.Error()), nil
}

// WithMessage returns a copy of e whose ErrMsg is formatted from msg and args.
func (e WrappedError) WithMessage(msg string, args ...interface{}) WrappedError {
	return WrappedError{
		ErrIs:  e.ErrIs,
		Type:   e.Type,
		ErrMsg: fmt.Sprintf(msg, args...),
	}
}

// compareTargeter is an internal interface that lets types embedding WrappedError be compared
// with another WrappedError, even when they override Is.
type compareTargeter interface {
	target() WrappedError
}

func (e WrappedError) target() WrappedError {
	return e
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package utils
import (
"fmt"
"reflect"
)
const (
	errTmplIncorrectSignature = `incorrect signature [%T] for RecoverableFunc, check the document for usage`
)

var (
	// typeError is the reflect.Type of the built-in error interface.
	typeError = reflect.TypeOf((*error)(nil)).Elem()
)

// A SupportedRecoverableFunc is a function type that RecoverableFunc can convert.
type SupportedRecoverableFunc interface {
	~func() | ~func() error
}

// RecoverableFunc wraps a panicking function with one of the following signatures
//   - func()
//   - func() error
//
// into a func() error. A recovered panic value is returned as the error: it is returned as-is when it
// is already an error, and otherwise formatted with %v. When there is no panic, the wrapped function's
// own return value is passed through (nil for func()).
// This function panics if the given function has an incorrect signature.
func RecoverableFunc[T SupportedRecoverableFunc](panicingFunc T) func() error {
	fnValue := reflect.ValueOf(panicingFunc)
	fnType := fnValue.Type()
	// The generic constraint already guarantees a function without parameters,
	// so only the return signature is validated here.
	var invoke func() error
	switch outs := fnType.NumOut(); {
	case outs == 0:
		invoke = func() error {
			fnValue.Call(nil)
			return nil
		}
	case outs == 1 && fnType.Out(0).AssignableTo(typeError):
		invoke = func() error {
			out := fnValue.Call(nil)[0]
			if out.IsNil() {
				return nil
			}
			return out.Interface().(error)
		}
	default:
		panic(fmt.Sprintf(errTmplIncorrectSignature, panicingFunc))
	}

	return func() (err error) {
		defer func() {
			recovered := recover()
			if recovered == nil {
				return
			}
			if asErr, ok := recovered.(error); ok {
				err = asErr
				return
			}
			err = fmt.Errorf("%v", recovered)
		}()
		return invoke()
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package loop
import (
"context"
"fmt"
"sync"
"time"
)
// TaskFunc can be scheduled to Loop with TaskOptions.
// The Loop invokes it on a separate goroutine with a context derived from the Loop's context and with the
// Loop itself, and waits for it to return (or for the Loop's context to end) before taking the next task.
// NOTE(review): because the Loop waits for the running task, a TaskFunc that synchronously calls
// Do/DoAndWait on the same Loop appears likely to deadlock; schedule follow-ups asynchronously.
type TaskFunc func(ctx context.Context, l *Loop) (ret interface{}, err error)

// RepeatIntervalFunc is used when scheduling a repeated TaskFunc.
// It takes the result and error of the previous TaskFunc invocation and determines the delay before the next TaskFunc invocation.
type RepeatIntervalFunc func(result interface{}, err error) time.Duration

// TaskOptions mutates a TaskOption; any number of them may be passed to Repeat, Do or DoAndWait.
type TaskOptions func(opt *TaskOption)

// TaskOption holds per-task settings.
type TaskOption struct {
	// RepeatIntervalFunc computes the delay between repetitions; only Repeat reads it.
	RepeatIntervalFunc RepeatIntervalFunc
}

// taskResult carries the outcome of one task invocation back to a waiting caller (DoAndWait).
type taskResult struct {
	ret interface{}
	err error
}

// task is a unit of work sent to the loop goroutine over Loop.taskCh.
type task struct {
	f        TaskFunc
	resultCh chan taskResult // nil unless the caller waits for the result (DoAndWait)
	opt      TaskOption
}

// Loop executes submitted tasks one at a time on a background goroutine started by Run.
type Loop struct {
	taskCh   chan *task         // created unbuffered by NewLoop: senders block until the loop goroutine receives
	mtx      sync.Mutex         // guards ctx and cancelFn while Run initializes them
	ctx      context.Context    // set by the first Run call; the loop goroutine exits when it is done
	cancelFn context.CancelFunc // cancels ctx
}
// NewLoop creates a Loop with an unbuffered task channel. The Loop does not process tasks until Run is called.
func NewLoop() *Loop {
	l := new(Loop)
	l.taskCh = make(chan *task)
	return l
}
// Run starts the loop goroutine under a cancellable child of ctx and returns that context and its cancel func.
// Only the first call starts the loop; later calls return the same context and cancel func.
func (l *Loop) Run(ctx context.Context) (context.Context, context.CancelFunc) {
	l.mtx.Lock()
	defer l.mtx.Unlock()
	if l.ctx != nil {
		return l.ctx, l.cancelFn
	}
	l.ctx, l.cancelFn = context.WithCancel(ctx)
	go l.loop(l.ctx)
	return l.ctx, l.cancelFn
}
func (l *Loop) Repeat(tf TaskFunc, opts ...TaskOptions) {
opt := TaskOption{
RepeatIntervalFunc: fixedRepeatIntervalFunc(10 * time.Millisecond),
}
for _, f := range opts {
f(&opt)
}
l.taskCh <- &task{
f: l.makeTaskFuncWithRepeat(tf, opt.RepeatIntervalFunc),
opt: opt,
}
}
func (l *Loop) Do(tf TaskFunc, opts ...TaskOptions) {
opt := TaskOption{}
for _, f := range opts {
f(&opt)
}
l.taskCh <- &task{
f: tf,
opt: opt,
}
}
func (l *Loop) DoAndWait(tf TaskFunc, opts ...TaskOptions) (interface{}, error) {
opt := TaskOption{}
for _, f := range opts {
f(&opt)
}
resultCh := make(chan taskResult)
defer close(resultCh)
l.taskCh <- &task{
f: tf,
resultCh: resultCh,
opt: opt,
}
select {
case result := <-resultCh:
return result.ret, result.err
}
}
// loop receives tasks from taskCh and executes them one at a time until ctx is done.
func (l *Loop) loop(ctx context.Context) {
	done := ctx.Done()
	for {
		select {
		case <-done:
			return
		case next := <-l.taskCh:
			l.do(ctx, next)
		}
	}
}
// do executes t on its own goroutine and blocks until the task finishes or ctx is cancelled.
// The task receives a child context of ctx that is cancelled as soon as the task returns or panics.
// When t.resultCh is set, one taskResult is pushed to it:
//   - ctx.Err() if ctx was cancelled by the time the task returned
//   - otherwise the task's own result and error
//   - an error formatted from the recovered value if the task panicked
//
// Because of this, the channel never needs to be closed here.
func (l *Loop) do(ctx context.Context, t *task) {
	// cancellation of the parent ctx propagates to execCtx
	execCtx, finish := context.WithCancel(ctx)
	go func() {
		defer func() {
			if recovered := recover(); recovered != nil && t.resultCh != nil {
				t.resultCh <- taskResult{err: fmt.Errorf("%v", recovered)}
			}
			finish()
		}()
		ret, err := t.f(execCtx, l)
		if t.resultCh == nil {
			return
		}
		if ctxErr := ctx.Err(); ctxErr != nil {
			// the loop was stopped while the task ran
			t.resultCh <- taskResult{err: ctxErr}
			return
		}
		t.resultCh <- taskResult{ret: ret, err: err}
	}()
	// wakes up when the task calls finish or the parent ctx is cancelled
	<-execCtx.Done()
}
// makeTaskFuncWithRepeat make a func that execute given TaskFunc and reschedule itself after given "interval"
func (l *Loop) makeTaskFuncWithRepeat(tf TaskFunc, intervalFunc RepeatIntervalFunc) TaskFunc {
return func(ctx context.Context, l *Loop) (ret interface{}, err error) {
// reschedule after delayed time
defer func() {
interval := intervalFunc(ret, err)
l.repeatAfter(l.makeTaskFuncWithRepeat(tf, intervalFunc), interval)
}()
ret, err = tf(ctx, l)
return
}
}
// repeatAfter asynchronously submits tf to the loop after interval, unless the loop's context ends first.
// The submission itself also gives up when the context ends. Without that, a cancellation right after the
// timer fired would leave this goroutine blocked forever on the unbuffered task channel, since the loop
// goroutine has already exited.
// It must only be called after Run, because it relies on l.ctx.
func (l *Loop) repeatAfter(tf TaskFunc, interval time.Duration) {
	ctx := l.ctx
	go func() {
		timer := time.NewTimer(interval)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-ctx.Done():
			return
		}
		select {
		case l.taskCh <- &task{f: tf}:
		case <-ctx.Done():
		}
	}()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package loop
import (
"time"
)
// FixedRepeatInterval returns a TaskOptions that makes a repeated task wait the same "interval"
// between invocations, regardless of the previous result or error.
func FixedRepeatInterval(interval time.Duration) TaskOptions {
	apply := func(opt *TaskOption) {
		opt.RepeatIntervalFunc = fixedRepeatIntervalFunc(interval)
	}
	return apply
}

// fixedRepeatIntervalFunc returns a RepeatIntervalFunc that always yields "interval".
func fixedRepeatIntervalFunc(interval time.Duration) RepeatIntervalFunc {
	return func(interface{}, error) time.Duration { return interval }
}
// ExponentialRepeatIntervalOnError returns a TaskOptions whose repeat interval is multiplied by "factor"
// after every invocation that returns a non-nil error, and reset to "init" after an invocation without error.
// It panics if factor is less than 1.
func ExponentialRepeatIntervalOnError(init time.Duration, factor float64) TaskOptions {
	if factor < 1 {
		panic("attempt to use ExponentialRepeatIntervalOnError with a factor less than 1")
	}
	apply := func(opt *TaskOption) {
		// every application gets its own, independent back-off state
		opt.RepeatIntervalFunc = exponentialRepeatIntervalOnErrorFunc(init, factor)
	}
	return apply
}
func exponentialRepeatIntervalOnErrorFunc(init time.Duration, factor float64) RepeatIntervalFunc {
curr := init
return func(_ interface{}, err error) time.Duration {
if err == nil {
curr = init
} else {
curr = time.Duration(float64(curr) * factor)
}
return curr
}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package matcher
import (
"context"
"fmt"
"strings"
)
// Matcher decides whether a given value matches.
// Both methods return (matched, err). Implementations in this package return a non-nil error,
// e.g., when the value has a type they do not support.
type Matcher interface {
	Matches(interface{}) (bool, error)
	MatchesWithContext(context.Context, interface{}) (bool, error)
}

// ChainableMatcher is a Matcher that can be combined with other matchers.
type ChainableMatcher interface {
	Matcher
	// Or concat given matchers with OR operator
	Or(matcher ...Matcher) ChainableMatcher
	// And concat given matchers with AND operator
	And(matcher ...Matcher) ChainableMatcher
}
// Any returns a matcher that matches everything.
func Any() ChainableMatcher {
	var matchAll NoopMatcher = true
	return matchAll
}

// None returns a matcher that matches nothing.
func None() ChainableMatcher {
	var matchNothing NoopMatcher = false
	return matchNothing
}

// Or concatenates the given matchers with the OR operator; "left" is evaluated first.
func Or(left Matcher, right ...Matcher) ChainableMatcher {
	chain := make(OrMatcher, 0, len(right)+1)
	chain = append(chain, left)
	return append(chain, right...)
}

// And concatenates the given matchers with the AND operator; "left" is evaluated first.
func And(left Matcher, right ...Matcher) ChainableMatcher {
	chain := make(AndMatcher, 0, len(right)+1)
	chain = append(chain, left)
	return append(chain, right...)
}

// Not returns a matcher that negates the result of "matcher".
func Not(matcher Matcher) ChainableMatcher {
	return &NegateMatcher{Matcher: matcher}
}
// NoopMatcher ignores its input and returns its own boolean value as the match result.
type NoopMatcher bool

// Matches returns bool(m) for any input.
func (m NoopMatcher) Matches(_ interface{}) (bool, error) {
	return bool(m), nil
}

// MatchesWithContext returns bool(m) for any context and input.
func (m NoopMatcher) MatchesWithContext(_ context.Context, _ interface{}) (bool, error) {
	return bool(m), nil
}

// Or combines m with matchers using the OR operator.
func (m NoopMatcher) Or(matchers ...Matcher) ChainableMatcher {
	return Or(m, matchers...)
}

// And combines m with matchers using the AND operator.
func (m NoopMatcher) And(matchers ...Matcher) ChainableMatcher {
	return And(m, matchers...)
}

// String describes the matcher.
func (m NoopMatcher) String() string {
	if !m {
		return "matches none"
	}
	return "matches any"
}
// OrMatcher chains a list of matchers with the OR operator.
// Evaluation is left to right and stops at the first matcher that matches or returns an error; that
// matcher's result is returned. An empty OrMatcher matches nothing.
type OrMatcher []Matcher

// Matches evaluates the chain against i.
func (m OrMatcher) Matches(i interface{}) (ret bool, err error) {
	for _, item := range m {
		if ret, err = item.Matches(i); ret || err != nil {
			break
		}
	}
	return
}

// MatchesWithContext evaluates the chain against i, passing c to every matcher.
func (m OrMatcher) MatchesWithContext(c context.Context, i interface{}) (ret bool, err error) {
	for _, item := range m {
		if ret, err = item.MatchesWithContext(c, i); ret || err != nil {
			break
		}
	}
	return
}

// Or appends matchers to this chain with the OR operator.
func (m OrMatcher) Or(matchers ...Matcher) ChainableMatcher {
	return Or(m, matchers...)
}

// And combines this chain with matchers using the AND operator.
func (m OrMatcher) And(matchers ...Matcher) ChainableMatcher {
	return And(m, matchers...)
}

// String joins the descriptions of all matchers with " OR ".
func (m OrMatcher) String() string {
	descs := make([]string, len(m))
	for i, item := range m {
		// %v uses String() for fmt.Stringer items and default formatting otherwise, instead of
		// panicking on a failed type assertion for Matchers that do not implement fmt.Stringer.
		descs[i] = fmt.Sprintf("%v", item)
	}
	return strings.Join(descs, " OR ")
}
// AndMatcher chains a list of matchers with the AND operator.
// Evaluation is left to right and stops at the first matcher that does not match or returns an error;
// that matcher's result is returned. An empty AndMatcher matches nothing.
type AndMatcher []Matcher

// Matches evaluates the chain against i.
func (m AndMatcher) Matches(i interface{}) (ret bool, err error) {
	for _, item := range m {
		if ret, err = item.Matches(i); !ret || err != nil {
			break
		}
	}
	return
}

// MatchesWithContext evaluates the chain against i, passing c to every matcher.
func (m AndMatcher) MatchesWithContext(c context.Context, i interface{}) (ret bool, err error) {
	for _, item := range m {
		if ret, err = item.MatchesWithContext(c, i); !ret || err != nil {
			break
		}
	}
	return
}

// Or combines this chain with matchers using the OR operator.
func (m AndMatcher) Or(matchers ...Matcher) ChainableMatcher {
	return Or(m, matchers...)
}

// And appends matchers to this chain with the AND operator.
func (m AndMatcher) And(matchers ...Matcher) ChainableMatcher {
	return And(m, matchers...)
}

// String joins the descriptions of all matchers with " AND ".
func (m AndMatcher) String() string {
	descs := make([]string, len(m))
	for i, item := range m {
		// %v uses String() for fmt.Stringer items and default formatting otherwise, instead of
		// panicking on a failed type assertion for Matchers that do not implement fmt.Stringer.
		descs[i] = fmt.Sprintf("%v", item)
	}
	return strings.Join(descs, " AND ")
}
// NegateMatcher applies the ! operator to the result of the embedded Matcher.
// Errors from the embedded Matcher are passed through unchanged (the boolean is still negated).
// Matches, Or and And have pointer receivers, so use *NegateMatcher (as returned by Not) as a ChainableMatcher.
type NegateMatcher struct {
	Matcher
}

// Matches returns the negated result of the embedded Matcher.
func (m *NegateMatcher) Matches(i interface{}) (ret bool, err error) {
	matched, e := m.Matcher.Matches(i)
	return !matched, e
}

// MatchesWithContext returns the negated result of the embedded Matcher.
func (m NegateMatcher) MatchesWithContext(c context.Context, i interface{}) (ret bool, err error) {
	matched, e := m.Matcher.MatchesWithContext(c, i)
	return !matched, e
}

// Or combines m with matchers using the OR operator.
func (m *NegateMatcher) Or(matchers ...Matcher) ChainableMatcher {
	return Or(m, matchers...)
}

// And combines m with matchers using the AND operator.
func (m *NegateMatcher) And(matchers ...Matcher) ChainableMatcher {
	return And(m, matchers...)
}

// String describes the matcher as Not(<embedded matcher>).
func (m NegateMatcher) String() string {
	return fmt.Sprintf("Not(%v)", m.Matcher)
}
// GenericMatcher implements ChainableMatcher by delegating to matchFunc.
// TODO review use cases to determine if this class is necessary
type GenericMatcher struct {
	matchFunc func(context.Context, interface{}) (bool, error)
}

// Matches is MatchesWithContext with context.TODO().
func (m *GenericMatcher) Matches(i interface{}) (bool, error) {
	return m.MatchesWithContext(context.TODO(), i)
}

// MatchesWithContext delegates to matchFunc.
func (m *GenericMatcher) MatchesWithContext(c context.Context, i interface{}) (ret bool, err error) {
	return m.matchFunc(c, i)
}

// Or combines m with matchers using the OR operator.
func (m *GenericMatcher) Or(matchers ...Matcher) ChainableMatcher {
	return Or(m, matchers...)
}

// And combines m with matchers using the AND operator.
func (m *GenericMatcher) And(matchers ...Matcher) ChainableMatcher {
	return And(m, matchers...)
}

// String describes the matcher using the type of its match func.
func (m *GenericMatcher) String() string {
	return fmt.Sprintf("generic matcher with func [%T]", m.matchFunc)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package matcher
import (
"context"
"fmt"
"github.com/bmatcuk/doublestar/v4"
"regexp"
"strings"
)
const (
	// descSuffixCaseInsensitive is appended to a matcher's description when the matcher ignores case.
	descSuffixCaseInsensitive = `, case insensitive`
)

// StringMatcher is a typed ChainableMatcher that accepts strings.
// Matchers built by the constructors in this file return an error from Matches/MatchesWithContext when
// given a non-string value.
type StringMatcher interface {
	ChainableMatcher
}
// stringMatcher implements ChainableMatcher and StringMatcher and only accepts strings.
type stringMatcher struct {
	description string
	matchFunc   func(context.Context, string) (bool, error)
}

// StringMatches runs the match func against value directly, without a type check.
func (m *stringMatcher) StringMatches(c context.Context, value string) (bool, error) {
	return m.matchFunc(c, value)
}

// Matches is MatchesWithContext with context.TODO().
func (m *stringMatcher) Matches(i interface{}) (bool, error) {
	return m.MatchesWithContext(context.TODO(), i)
}

// MatchesWithContext matches i if it is a string; any other type yields (false, error).
func (m *stringMatcher) MatchesWithContext(c context.Context, i interface{}) (bool, error) {
	if v, ok := i.(string); ok {
		return m.StringMatches(c, v)
	}
	return false, fmt.Errorf("StringMatcher doesn't support %T", i)
}

// Or combines m with matchers using the OR operator.
func (m *stringMatcher) Or(matchers ...Matcher) ChainableMatcher {
	return Or(m, matchers...)
}

// And combines m with matchers using the AND operator.
func (m *stringMatcher) And(matchers ...Matcher) ChainableMatcher {
	return And(m, matchers...)
}

// String returns the human-readable description given at construction.
func (m *stringMatcher) String() string {
	return m.description
}
/**************************
Constructors
***************************/
// WithString returns a StringMatcher that matches values equal to "expected", optionally ignoring case.
func WithString(expected string, caseInsensitive bool) StringMatcher {
	desc := fmt.Sprintf("matches [%s]", expected)
	if caseInsensitive {
		desc = desc + descSuffixCaseInsensitive
	}
	return &stringMatcher{
		matchFunc: func(_ context.Context, value string) (bool, error) {
			return MatchString(expected, value, caseInsensitive), nil
		},
		description: desc,
	}
}

// WithSubString returns a StringMatcher that matches values containing "substr", optionally ignoring case.
func WithSubString(substr string, caseInsensitive bool) StringMatcher {
	desc := fmt.Sprintf("contains [%s]", substr)
	if caseInsensitive {
		desc = desc + descSuffixCaseInsensitive
	}
	return &stringMatcher{
		matchFunc: func(_ context.Context, value string) (bool, error) {
			return MatchSubString(substr, value, caseInsensitive), nil
		},
		description: desc,
	}
}

// AnyNonEmptyString returns a StringMatcher that matches every string except "".
func AnyNonEmptyString() StringMatcher {
	return &stringMatcher{
		matchFunc: func(_ context.Context, value string) (bool, error) {
			return value != "", nil
		},
		description: "matches any non-empty string",
	}
}

// WithPathPattern returns a StringMatcher that matches values against a path pattern (see MatchPathPattern).
func WithPathPattern(pattern string) StringMatcher {
	return &stringMatcher{
		matchFunc: func(_ context.Context, value string) (bool, error) {
			return MatchPathPattern(pattern, value)
		},
		description: fmt.Sprintf("matches pattern [%s]", pattern),
	}
}

// WithPrefix returns a StringMatcher that matches values starting with "prefix", optionally ignoring case.
func WithPrefix(prefix string, caseInsensitive bool) StringMatcher {
	desc := fmt.Sprintf("start with [%s]", prefix)
	if caseInsensitive {
		desc = desc + descSuffixCaseInsensitive
	}
	return &stringMatcher{
		matchFunc: func(_ context.Context, value string) (bool, error) {
			return MatchPrefix(prefix, value, caseInsensitive)
		},
		description: desc,
	}
}

// WithSuffix returns a StringMatcher that matches values ending with "suffix", optionally ignoring case.
func WithSuffix(suffix string, caseInsensitive bool) StringMatcher {
	desc := fmt.Sprintf("ends with [%s]", suffix)
	if caseInsensitive {
		desc = desc + descSuffixCaseInsensitive
	}
	return &stringMatcher{
		matchFunc: func(_ context.Context, value string) (bool, error) {
			return MatchSuffix(suffix, value, caseInsensitive)
		},
		description: desc,
	}
}

// WithRegex returns a StringMatcher that matches values containing a match of the regular expression "regex".
// The expression is compiled once, here, rather than on every match. An invalid expression does not fail
// construction: every match returns (false, compile error), exactly as regexp.MatchString would.
func WithRegex(regex string) StringMatcher {
	compiled, compileErr := regexp.Compile(regex)
	return &stringMatcher{
		matchFunc: func(_ context.Context, value string) (bool, error) {
			if compileErr != nil {
				return false, compileErr
			}
			return compiled.MatchString(value), nil
		},
		description: fmt.Sprintf("matches regex [%s]", regex),
	}
}

// WithRegexPattern returns a StringMatcher that matches values containing a match of the compiled "regex".
// It panics if regex is nil.
func WithRegexPattern(regex *regexp.Regexp) StringMatcher {
	return &stringMatcher{
		matchFunc: func(_ context.Context, value string) (bool, error) {
			return MatchRegexPattern(regex, value)
		},
		description: fmt.Sprintf("matches regex [%s]", regex.String()),
	}
}
/**************************
helpers
***************************/
// MatchString reports whether actual equals expected. When caseInsensitive is set, both sides are
// lower-cased with strings.ToLower before comparing.
func MatchString(expected, actual string, caseInsensitive bool) bool {
	if !caseInsensitive {
		return expected == actual
	}
	return strings.ToLower(expected) == strings.ToLower(actual)
}
// MatchSubString reports whether actual contains substr. When caseInsensitive is set, both sides are
// lower-cased with strings.ToLower first.
func MatchSubString(substr, actual string, caseInsensitive bool) bool {
	if !caseInsensitive {
		return strings.Contains(actual, substr)
	}
	return strings.Contains(strings.ToLower(actual), strings.ToLower(substr))
}
// MatchPathPattern reports whether path matches the given path pattern, using doublestar.Match.
// An empty pattern matches every path.
// The pattern syntax is:
//
//	pattern:
//	  { term }
//	term:
//	  '*'         matches any sequence of non-path-separators
//	  '**'        matches any sequence of characters, including
//	              path separators.
//	  '?'         matches any single non-path-separator character
//	  '[' [ '^' ] { character-range } ']'
//	              character class (must be non-empty)
//	  '{' { term } [ ',' { term } ... ] '}'
//	  c           matches character c (c != '*', '?', '\\', '[')
//	  '\\' c      matches character c
//
//	character-range:
//	  c           matches character c (c != '\\', '-', ']')
//	  '\\' c      matches character c
//	  lo '-' hi   matches character c for lo <= c <= hi
func MatchPathPattern(pattern, path string) (bool, error) {
	if pattern != "" {
		return doublestar.Match(pattern, path)
	}
	return true, nil
}
// MatchPrefix reports whether value starts with prefix, lower-casing both first when caseInsensitive is set.
// The error is always nil; it exists so the signature fits stringMatcher's match func.
func MatchPrefix(prefix, value string, caseInsensitive bool) (bool, error) {
	if caseInsensitive {
		prefix, value = strings.ToLower(prefix), strings.ToLower(value)
	}
	return strings.HasPrefix(value, prefix), nil
}

// MatchSuffix reports whether value ends with suffix, lower-casing both first when caseInsensitive is set.
// The error is always nil; it exists so the signature fits stringMatcher's match func.
func MatchSuffix(suffix, value string, caseInsensitive bool) (bool, error) {
	if caseInsensitive {
		suffix, value = strings.ToLower(suffix), strings.ToLower(value)
	}
	return strings.HasSuffix(value, suffix), nil
}
// MatchRegex compiles regex and reports whether value contains a match of it.
// A compile error is returned as-is, with a false result.
func MatchRegex(regex, value string) (bool, error) {
	re, err := regexp.Compile(regex)
	if err != nil {
		return false, err
	}
	return re.MatchString(value), nil
}

// MatchRegexPattern reports whether value contains a match of the compiled regex. The error is always nil.
func MatchRegexPattern(regex *regexp.Regexp, value string) (bool, error) {
	matched := regex.MatchString(value)
	return matched, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package netutil
import (
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strings"
)
// GetIp returns an IPv4 address of the network interface named iface, as a string.
//
// Interfaces whose name contains "utun" (on macOS typically the VPN or "Back to My Mac" interface) are
// skipped, unless iface names one explicitly. Interfaces are scanned in the order returned by
// net.Interfaces, and the scan stops after the interface named iface. If no interface has that name, or
// iface is "", the most recently found IPv4 address across the scanned interfaces is returned.
// NOTE(review): an interface named iface that has no IPv4 address does not produce an error when an earlier
// interface had one; that earlier address is returned. Confirm this fallback is intended.
func GetIp(iface string) (string, error) {
	ifaces, err := net.Interfaces()
	if err != nil {
		return "", err
	}
	var ip net.IP
	for _, i := range ifaces {
		name := i.Name
		if iface != name && strings.Contains(name, "utun") {
			continue
		}
		addrs, e := i.Addrs()
		if e != nil {
			return "", e
		}
		for _, addr := range addrs {
			var candidate net.IP
			switch v := addr.(type) {
			case *net.IPNet:
				candidate = v.IP
			case *net.IPAddr:
				candidate = v.IP
			}
			// To4 is nil for IPv6 addresses and for a nil candidate (unsupported address type)
			if candidate.To4() != nil {
				ip = candidate
			}
		}
		if iface == name {
			break
		}
	}
	if ip == nil {
		if iface == "" {
			return "", errors.New("No valid interface or address found")
		}
		return "", fmt.Errorf("Interface %s not found or no address", iface)
	}
	return ip.String(), nil
}
func GetForwardedHostName(request *http.Request) string {
var host string
fwdAddress := request.Header.Get("X-Forwarded-Host") // capitalisation doesn't matter
if fwdAddress != "" {
ips := strings.Split(fwdAddress, ",")
orig := strings.TrimSpace(ips[0])
reqHost, _, err := net.SplitHostPort(orig)
if err == nil {
host = reqHost
} else {
host = orig
}
} else {
reqHost, _, err := net.SplitHostPort(request.Host)
if err == nil {
host = reqHost
} else {
host = request.Host
}
}
return host
}
// AppendRedirectUrl adds params to the query of the absolute URL redirectUrl and returns the result.
// Existing query parameters are kept and the final query is encoded with keys in sorted order.
// It returns an "invalid redirect_uri" error if redirectUrl cannot be parsed as a request URI or is not absolute.
func AppendRedirectUrl(redirectUrl string, params map[string]string) (string, error) {
	target, err := url.ParseRequestURI(redirectUrl)
	if err != nil || !target.IsAbs() {
		return "", errors.New("invalid redirect_uri")
	}
	// TODO support fragments
	values := target.Query()
	for key, val := range params {
		values.Add(key, val)
	}
	target.RawQuery = values.Encode()
	return target.String(), nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package order
import (
"fmt"
"reflect"
"sort"
)
const (
	// Lowest is the largest int. As an order value it sorts last under the ascending compare functions below.
	Lowest = int(^uint(0) >> 1) // max int
	// Highest is the smallest int. As an order value it sorts first under the ascending compare functions below.
	Highest = -Lowest - 1 // min int
)

// Ordered is implemented by values that declare an order value; smaller values sort first in the non-reverse compare functions.
type Ordered interface {
	Order() int
}

// PriorityOrdered is implemented by values that declare a priority order value.
// The compare functions below treat PriorityOrdered as a higher-precedence category than Ordered.
type PriorityOrdered interface {
	PriorityOrder() int
}

// LessFunc is accepted less func by sort.Slice and sort.SliceStable
type LessFunc func(i, j int) bool

// CompareFunc reports whether "left" should sort before "right".
type CompareFunc func(left interface{}, right interface{}) bool
// Sort sorts "slice" in place with sort.Slice, using a less func that passes the elements at i and j
// (via reflection) to compareFunc. The sort is not stable.
// It panics if the given value is not a slice.
func Sort(slice interface{}, compareFunc CompareFunc) {
	sv := reflect.ValueOf(slice)
	if kind := sv.Kind(); kind != reflect.Slice {
		panic(fmt.Errorf("Sort only support slice, but got %T", slice))
	}
	less := func(i, j int) bool {
		return compareFunc(sv.Index(i).Interface(), sv.Index(j).Interface())
	}
	sort.Slice(slice, less)
}
// SortStable sorts "slice" in place with sort.SliceStable, using a less func that passes the elements at i and j
// (via reflection) to compareFunc. Elements that compare equal keep their original relative order.
// It panics if the given value is not a slice.
func SortStable(slice interface{}, compareFunc CompareFunc) {
	sv := reflect.ValueOf(slice)
	if sv.Kind() != reflect.Slice {
		panic(fmt.Errorf("SortStable only support slice, but got %T", slice))
	}
	sort.SliceStable(slice, func(i, j int) bool {
		return compareFunc(sv.Index(i).Interface(), sv.Index(j).Interface())
	})
}
// OrderedFirstCompare compares objects based on order interfaces with the following rules:
//   - PriorityOrdered comes before any other type
//   - Ordered comes before anything that is not PriorityOrdered
//   - within the same category, the smaller order value comes first
//
// Two objects that are neither PriorityOrdered nor Ordered are considered equal (false).
func OrderedFirstCompare(left interface{}, right interface{}) bool {
	lp, leftPriority := left.(PriorityOrdered)
	rp, rightPriority := right.(PriorityOrdered)
	if leftPriority && rightPriority {
		return lp.PriorityOrder() < rp.PriorityOrder()
	}
	if leftPriority || rightPriority {
		// exactly one side is PriorityOrdered; it goes first
		return leftPriority
	}

	lo, leftOrdered := left.(Ordered)
	ro, rightOrdered := right.(Ordered)
	if leftOrdered && rightOrdered {
		return lo.Order() < ro.Order()
	}
	// exactly one Ordered side goes first; two plain objects keep their relative position
	return leftOrdered
}
// OrderedFirstCompareReverse compares objects based on order interfaces with the same rules as
// OrderedFirstCompare, but reversed: plain objects first, then Ordered, then PriorityOrdered, and the larger
// order value first within a category. Two plain objects are considered equal (false).
func OrderedFirstCompareReverse(left interface{}, right interface{}) bool {
	lp, leftPriority := left.(PriorityOrdered)
	rp, rightPriority := right.(PriorityOrdered)
	if leftPriority && rightPriority {
		return lp.PriorityOrder() > rp.PriorityOrder()
	}
	if leftPriority || rightPriority {
		// exactly one side is PriorityOrdered; it goes last
		return rightPriority
	}

	lo, leftOrdered := left.(Ordered)
	ro, rightOrdered := right.(Ordered)
	if leftOrdered && rightOrdered {
		return lo.Order() > ro.Order()
	}
	// exactly one Ordered side goes last; two plain objects keep their relative position
	return rightOrdered
}
// OrderedLastCompare compares objects based on order interfaces with the following rules:
//   - a regular object (neither PriorityOrdered nor Ordered) comes before any other type
//   - PriorityOrdered comes before Ordered
//   - within the same category, the smaller order value comes first
//
// Two regular objects are considered equal (false).
func OrderedLastCompare(left interface{}, right interface{}) bool {
	lp, leftPriority := left.(PriorityOrdered)
	rp, rightPriority := right.(PriorityOrdered)
	lo, leftOrdered := left.(Ordered)
	ro, rightOrdered := right.(Ordered)
	leftPlain := !leftPriority && !leftOrdered
	rightPlain := !rightPriority && !rightOrdered

	switch {
	case leftPlain || rightPlain:
		// a plain object goes first; two plain objects have no order between them
		return leftPlain && !rightPlain
	case leftPriority && rightPriority:
		return lp.PriorityOrder() < rp.PriorityOrder()
	case leftPriority || rightPriority:
		// exactly one side is PriorityOrdered; it goes before the Ordered side
		return leftPriority
	case leftOrdered && rightOrdered:
		return lo.Order() < ro.Order()
	default:
		return false // unreachable: both sides are non-plain and neither is PriorityOrdered
	}
}
// OrderedLastCompareReverse compares objects based on order interfaces with the same rules as
// OrderedLastCompare, but reversed: Ordered first, then PriorityOrdered, then regular objects, and the larger
// order value first within a category. Two regular objects are considered equal (false).
func OrderedLastCompareReverse(left interface{}, right interface{}) bool {
	lp, leftPriority := left.(PriorityOrdered)
	rp, rightPriority := right.(PriorityOrdered)
	lo, leftOrdered := left.(Ordered)
	ro, rightOrdered := right.(Ordered)
	leftPlain := !leftPriority && !leftOrdered
	rightPlain := !rightPriority && !rightOrdered

	switch {
	case leftPlain || rightPlain:
		// a plain object goes last; two plain objects have no order between them
		return rightPlain && !leftPlain
	case leftPriority && rightPriority:
		return lp.PriorityOrder() > rp.PriorityOrder()
	case leftPriority || rightPriority:
		// exactly one side is PriorityOrdered; it goes after the Ordered side
		return rightPriority
	case leftOrdered && rightOrdered:
		return lo.Order() > ro.Order()
	default:
		return false // unreachable: both sides are non-plain and neither is PriorityOrdered
	}
}
// UnorderedMiddleCompare compares objects based on order interfaces with the following rules:
//   - PriorityOrdered comes before any other type
//   - a regular object (neither PriorityOrdered nor Ordered) comes before Ordered
//   - Ordered comes last
//   - within the same category, the smaller order value comes first
//
// Two regular objects are considered equal (false), so their natural order is kept by a stable sort.
func UnorderedMiddleCompare(left interface{}, right interface{}) bool {
	lp, leftPriority := left.(PriorityOrdered)
	rp, rightPriority := right.(PriorityOrdered)
	if leftPriority || rightPriority {
		if leftPriority && rightPriority {
			return lp.PriorityOrder() < rp.PriorityOrder()
		}
		// exactly one side is PriorityOrdered; it goes first
		return leftPriority
	}

	lo, leftOrdered := left.(Ordered)
	ro, rightOrdered := right.(Ordered)
	if leftOrdered && rightOrdered {
		return lo.Order() < ro.Order()
	}
	// an unordered left goes before an ordered right; every other combination is "not less"
	return !leftOrdered && rightOrdered
}
// UnorderedMiddleCompareReverse compares objects based on order interfaces with the same rules as
// UnorderedMiddleCompare, but reversed: Ordered first, then regular objects, then PriorityOrdered, and the
// larger order value first within a category.
// It swaps the operands rather than negating the result. Negation would report equal elements (e.g., two
// regular objects, or equal order values) as "less" in both directions. That violates the strict ordering
// that sort.Slice/sort.SliceStable require and would reorder ties. Swapping matches how the other *Reverse
// functions in this file behave.
func UnorderedMiddleCompareReverse(left interface{}, right interface{}) bool {
	return UnorderedMiddleCompare(right, left)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package utils
import (
cryptorand "crypto/rand"
"math/big"
"math/rand"
"time"
)
const (
	// CharsetAlphanumeric contains the ASCII digits and letters in both cases (62 characters).
	CharsetAlphanumeric RandomCharset = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
	// CharsetAlphabetic contains the ASCII letters in both cases (52 characters).
	CharsetAlphabetic RandomCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)

// pseudoRand is the math/rand fallback used when reading from crypto/rand fails.
// NOTE(review): a *rand.Rand built with rand.New(rand.NewSource(...)) is not documented as safe for
// concurrent use, so concurrent fallback calls may race; consider guarding it with a mutex.
//nolint:gosec // this is used as fallback, better than not working
var pseudoRand = rand.New(rand.NewSource(time.Now().UnixNano()))

// RandomCharset is a string of candidate characters for random string generation.
// Random strings are generated by indexing it byte by byte, so it should contain only single-byte (ASCII)
// characters; multi-byte UTF-8 characters would be split into invalid fragments.
type RandomCharset string
// RandomString returns a random alphanumeric string (characters from CharsetAlphanumeric) of given "length".
// This function uses "crypto/rand" and falls back to "math/rand".
// It returns an empty string if length is non-positive.
func RandomString(length int) string {
	return RandomStringWithCharset(length, CharsetAlphanumeric)
}
// RandomStringWithCharset returns a random string of given "length" made only of bytes from "charset".
// Randomness comes from "crypto/rand", falling back to "math/rand" if the secure source fails.
// Bytes are chosen by rejection sampling, so every usable charset byte is equally likely (no modulo bias).
// The charset is indexed byte by byte, so it should contain only single-byte (ASCII) characters; only
// its first 256 bytes are used.
// It returns an empty string if length is non-positive or charset is empty.
func RandomStringWithCharset(length int, charset RandomCharset) string {
	if length <= 0 || len(charset) == 0 {
		return ""
	}
	size := len(charset)
	if size > 256 {
		size = 256
	}
	// limit is the largest multiple of size not exceeding 256; random bytes >= limit are discarded
	// so that b % size is uniformly distributed for every accepted byte b.
	limit := 256 - 256%size

	data := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(data) < length {
		chunk := buf[:length-len(data)]
		if _, e := cryptorand.Read(chunk); e != nil {
			// secure source unavailable: fill the remainder from the pseudo-random fallback
			for len(data) < length {
				data = append(data, charset[pseudoRand.Intn(size)])
			}
			break
		}
		for _, b := range chunk {
			if int(b) >= limit {
				continue
			}
			data = append(data, charset[int(b)%size])
		}
	}
	return string(data)
}
// RandomInt64N returns, as an int64, a non-negative uniform number in the half-open interval [0,n).
// The number is drawn from "crypto/rand"; if that fails, "math/rand" is used instead.
// It panics if n <= 0.
func RandomInt64N(n int64) int64 {
	if v, e := cryptorand.Int(cryptorand.Reader, big.NewInt(n)); e == nil {
		return v.Int64()
	}
	return pseudoRand.Int63n(n) //nolint:gosec // this is fallback method, better than not working
}
// RandomIntN returns, as an int, a non-negative uniform number in the half-open interval [0,n).
// It delegates to RandomInt64N, so it uses "crypto/rand" and falls back to "math/rand".
// It panics if n <= 0.
func RandomIntN(n int) int {
	return int(RandomInt64N(int64(n)))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package reflectutils
import (
	"reflect"
	"unicode"
	"unicode/utf8"
)
// IsExportedField reports whether the struct field's name starts with an upper-case letter,
// which is Go's rule for exported identifiers. The first rune (not the first byte) of the name is
// examined, so names starting with multi-byte UTF-8 letters are classified correctly.
// It returns false for a field with an empty name.
func IsExportedField(f reflect.StructField) bool {
	if f.Name == "" {
		return false
	}
	first, _ := utf8.DecodeRuneInString(f.Name)
	return unicode.IsUpper(first)
}
// FindStructField recursively finds a field that satisfies "matcher", including fields of embedded
// (anonymous) structs.
// "sType" may be a struct type or any level of pointer to a struct type; a nil type or any other kind
// yields found == false. Fields are visited from the last declared to the first; each field is tested
// before descending into it when it is embedded.
// For a field found inside an embedded struct, the returned StructField.Index is the full index path
// starting from "sType".
// Recursive embedding (e.g. a struct embedding a pointer to itself) is not followed again, so the
// search always terminates.
func FindStructField(sType reflect.Type, matcher func(t reflect.StructField) bool) (ret reflect.StructField, found bool) {
	return findStructFieldOnPath(sType, matcher, map[reflect.Type]struct{}{})
}

// ListStructField recursively finds all fields that satisfy "matcher", including fields of embedded
// (anonymous) structs, using the same visiting order and index-path rules as FindStructField.
// It returns nil when nothing matches or "sType" does not resolve to a struct type.
// Recursive embedding is not followed again, so the search always terminates.
func ListStructField(sType reflect.Type, matcher func(t reflect.StructField) bool) (ret []reflect.StructField) {
	return listStructFieldOnPath(sType, matcher, map[reflect.Type]struct{}{})
}

// resolveStructType strips pointer indirections from t and returns the struct type, or nil if t is nil or
// does not resolve to a struct.
func resolveStructType(t reflect.Type) reflect.Type {
	if t == nil {
		return nil
	}
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return nil
	}
	return t
}

// joinFieldIndexPath returns a newly allocated index path "prefix" followed by "suffix", so results never
// share backing arrays with reflect-provided slices or with each other.
func joinFieldIndexPath(prefix, suffix []int) []int {
	joined := make([]int, 0, len(prefix)+len(suffix))
	joined = append(joined, prefix...)
	return append(joined, suffix...)
}

// findStructFieldOnPath implements FindStructField. "onPath" holds the struct types currently being
// inspected by the recursion; meeting one of them again means recursive embedding.
func findStructFieldOnPath(sType reflect.Type, matcher func(t reflect.StructField) bool, onPath map[reflect.Type]struct{}) (reflect.StructField, bool) {
	t := resolveStructType(sType)
	if t == nil {
		return reflect.StructField{}, false
	}
	if _, visiting := onPath[t]; visiting {
		return reflect.StructField{}, false
	}
	onPath[t] = struct{}{}
	defer delete(onPath, t)

	for i := t.NumField() - 1; i >= 0; i-- {
		f := t.Field(i)
		if matcher(f) {
			return f, true
		}
		if !f.Anonymous {
			continue
		}
		if sub, ok := findStructFieldOnPath(f.Type, matcher, onPath); ok {
			// prefix the embedded field's own index to get the path from the outer type
			sub.Index = joinFieldIndexPath(f.Index, sub.Index)
			return sub, true
		}
	}
	return reflect.StructField{}, false
}

// listStructFieldOnPath implements ListStructField, with "onPath" as in findStructFieldOnPath.
func listStructFieldOnPath(sType reflect.Type, matcher func(t reflect.StructField) bool, onPath map[reflect.Type]struct{}) (ret []reflect.StructField) {
	t := resolveStructType(sType)
	if t == nil {
		return nil
	}
	if _, visiting := onPath[t]; visiting {
		return nil
	}
	onPath[t] = struct{}{}
	defer delete(onPath, t)

	for i := t.NumField() - 1; i >= 0; i-- {
		f := t.Field(i)
		if matcher(f) {
			ret = append(ret, f)
		}
		if !f.Anonymous {
			continue
		}
		for _, sub := range listStructFieldOnPath(f.Type, matcher, onPath) {
			// correct index path relative to the outer type
			sub.Index = joinFieldIndexPath(f.Index, sub.Index)
			ret = append(ret, sub)
		}
	}
	return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package utils
import (
"encoding/json"
"fmt"
)
type void struct{}

/** StringSet **/

// StringSet is a set of strings backed by a map with zero-size values.
// A nil StringSet can be read (Has, Values, ...) but Add panics on it, as with any nil map; use NewStringSet.
type StringSet map[string]void

// NewStringSet creates a StringSet containing the given values.
func NewStringSet(values ...string) StringSet {
	return make(StringSet, len(values)).Add(values...)
}

// NewStringSetFromSet creates a StringSet from the string elements of "set"; other elements are skipped.
func NewStringSetFromSet(set Set) StringSet {
	stringSet := make(StringSet, len(set))
	for k := range set {
		if str, ok := k.(string); ok {
			stringSet[str] = void{}
		}
	}
	return stringSet
}

// NewStringSetFrom creates a StringSet from a StringSet (copied), a Set, a []string or a []interface{}.
// Non-string elements of a Set or []interface{} are skipped. It panics for any other type.
func NewStringSetFrom(i interface{}) StringSet {
	switch v := i.(type) {
	case StringSet:
		return v.Copy()
	case Set:
		return NewStringSetFromSet(v)
	case []string:
		return NewStringSet(v...)
	case []interface{}:
		set := make(StringSet, len(v))
		for _, elem := range v {
			if s, ok := elem.(string); ok {
				set[s] = void{}
			}
		}
		return set
	default:
		panic(fmt.Errorf("new StringSet from unsupported type %T", i))
	}
}

// Add inserts the values into the set and returns the same set for chaining.
func (s StringSet) Add(values ...string) StringSet {
	for _, item := range values {
		s[item] = void{}
	}
	return s
}

// Remove deletes the values from the set and returns the same set for chaining.
func (s StringSet) Remove(values ...string) StringSet {
	for _, item := range values {
		delete(s, item)
	}
	return s
}

// Has reports whether value is in the set.
func (s StringSet) Has(value string) bool {
	_, ok := s[value]
	return ok
}

// HasAll reports whether every one of values is in the set (true when no values are given).
func (s StringSet) HasAll(values ...string) bool {
	for _, v := range values {
		if !s.Has(v) {
			return false
		}
	}
	return true
}

// Equals reports whether both sets contain exactly the same elements.
func (s StringSet) Equals(another StringSet) bool {
	if len(s) != len(another) {
		return false
	}
	// equal sizes: "another" is a subset of s iff the sets are equal
	for k := range another {
		if !s.Has(k) {
			return false
		}
	}
	return true
}

// Values returns the elements as a new slice, in unspecified (map iteration) order.
func (s StringSet) Values() []string {
	values := make([]string, len(s))
	var i int
	for item := range s {
		values[i] = item
		i++
	}
	return values
}

// Copy returns an independent copy of the set.
func (s StringSet) Copy() StringSet {
	cp := make(StringSet, len(s))
	for k := range s {
		cp[k] = void{}
	}
	return cp
}

// ToSet converts the set to a Set.
func (s StringSet) ToSet() Set {
	return NewSetFromStringSet(s)
}

// MarshalJSON json.Marshaler: the set is encoded as a JSON array of strings.
func (s StringSet) MarshalJSON() ([]byte, error) {
	return json.Marshal(s.Values())
}

// UnmarshalJSON json.Unmarshaler: decodes a JSON array of strings into a new set.
func (s *StringSet) UnmarshalJSON(data []byte) error {
	values := make([]string, 0)
	if err := json.Unmarshal(data, &values); err != nil {
		return err
	}
	*s = NewStringSet(values...)
	return nil
}
/** Interface Set **/

// Set is a set of arbitrary values backed by a map with zero-size values.
// Values must be comparable: adding e.g. a slice or a map panics, exactly as using it as a map key would.
type Set map[interface{}]void

// NewSet creates a Set containing the given values.
func NewSet(values ...interface{}) Set {
	return make(Set, len(values)).Add(values...)
}

// NewSetFromStringSet creates a Set containing every string of stringSet.
func NewSetFromStringSet(stringSet StringSet) Set {
	set := make(Set, len(stringSet))
	for k := range stringSet {
		set[k] = void{}
	}
	return set
}

// NewSetFrom creates a Set from a StringSet, a Set (copied), a []string or a []interface{}.
// It panics for any other type.
func NewSetFrom(i interface{}) Set {
	switch v := i.(type) {
	case StringSet:
		return NewSetFromStringSet(v)
	case Set:
		return v.Copy()
	case []string:
		set := make(Set, len(v))
		for _, s := range v {
			set[s] = void{}
		}
		return set
	case []interface{}:
		return NewSet(v...)
	default:
		panic(fmt.Errorf("new Set from unsupported type %T", i))
	}
}

// Add inserts the values into the set and returns the same set for chaining.
func (s Set) Add(values ...interface{}) Set {
	for _, item := range values {
		s[item] = void{}
	}
	return s
}

// Remove deletes the values from the set and returns the same set for chaining.
func (s Set) Remove(values ...interface{}) Set {
	for _, item := range values {
		delete(s, item)
	}
	return s
}

// Has reports whether value is in the set.
func (s Set) Has(value interface{}) bool {
	_, ok := s[value]
	return ok
}

// HasAll reports whether every one of values is in the set (true when no values are given).
func (s Set) HasAll(values ...interface{}) bool {
	for _, v := range values {
		if !s.Has(v) {
			return false
		}
	}
	return true
}

// Equals reports whether both sets contain exactly the same elements.
func (s Set) Equals(another Set) bool {
	if len(s) != len(another) {
		return false
	}
	// equal sizes: "another" is a subset of s iff the sets are equal
	for k := range another {
		if !s.Has(k) {
			return false
		}
	}
	return true
}

// Values returns the elements as a new slice, in unspecified (map iteration) order.
func (s Set) Values() []interface{} {
	values := make([]interface{}, len(s))
	var i int
	for item := range s {
		values[i] = item
		i++
	}
	return values
}

// Copy returns an independent copy of the set.
func (s Set) Copy() Set {
	cp := make(Set, len(s))
	for k := range s {
		cp[k] = void{}
	}
	return cp
}

// MarshalJSON json.Marshaler: the set is encoded as a JSON array.
func (s Set) MarshalJSON() ([]byte, error) {
	return json.Marshal(s.Values())
}

// UnmarshalJSON json.Unmarshaler: decodes a JSON array of scalars (string, number, bool, null) into a new set.
// Arrays and objects are rejected with an error because they decode to non-comparable Go values
// ([]interface{} / map[string]interface{}) that would make the set panic when used as map keys.
func (s *Set) UnmarshalJSON(data []byte) error {
	values := make([]interface{}, 0)
	if err := json.Unmarshal(data, &values); err != nil {
		return err
	}
	for _, v := range values {
		switch v.(type) {
		case []interface{}, map[string]interface{}:
			return fmt.Errorf("set element must be a JSON scalar, got %T", v)
		}
	}
	*s = NewSet(values...)
	return nil
}
/** Generic Set **/

// GenericSet is a set of comparable values of type T, backed by a map with zero-size values.
type GenericSet[T comparable] map[T]void

// NewGenericSet creates a GenericSet holding the given values.
func NewGenericSet[T comparable](values ...T) GenericSet[T] {
	set := make(GenericSet[T])
	return set.Add(values...)
}

// Add inserts every value and returns the receiver so calls can be chained.
func (s GenericSet[T]) Add(values ...T) GenericSet[T] {
	for idx := range values {
		s[values[idx]] = void{}
	}
	return s
}

// Remove deletes every value and returns the receiver so calls can be chained.
func (s GenericSet[T]) Remove(values ...T) GenericSet[T] {
	for idx := range values {
		delete(s, values[idx])
	}
	return s
}

// Has reports membership of a single value.
func (s GenericSet[T]) Has(value T) bool {
	_, present := s[value]
	return present
}

// HasAll reports whether all given values are members; it is true for an empty argument list.
func (s GenericSet[T]) HasAll(values ...T) bool {
	for _, v := range values {
		if _, present := s[v]; !present {
			return false
		}
	}
	return true
}

// Equals reports whether the two sets contain the same members.
func (s GenericSet[T]) Equals(another GenericSet[T]) bool {
	switch {
	case len(s) != len(another):
		return false
	case len(s) == 0:
		return true
	}
	for k := range another {
		if _, present := s[k]; !present {
			return false
		}
	}
	return true
}

// Values lists the members in a new, never-nil slice; the order follows map iteration and is unspecified.
func (s GenericSet[T]) Values() []T {
	out := make([]T, 0, len(s))
	for item := range s {
		out = append(out, item)
	}
	return out
}

// Copy produces an independent duplicate of the set.
func (s GenericSet[T]) Copy() GenericSet[T] {
	duplicate := NewGenericSet[T]()
	for k := range s {
		duplicate[k] = void{}
	}
	return duplicate
}

// MarshalJSON json.Marshaler: encodes the members as a JSON array.
func (s GenericSet[T]) MarshalJSON() ([]byte, error) {
	return json.Marshal(s.Values())
}

// UnmarshalJSON json.Unmarshaler: replaces the set with the members of a JSON array.
func (s *GenericSet[T]) UnmarshalJSON(data []byte) error {
	var values []T
	if err := json.Unmarshal(data, &values); err != nil {
		return err
	}
	*s = NewGenericSet(values...)
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package utils
import (
"encoding/json"
"reflect"
"strings"
)
// Reverse reverses, in place, the order of the elements of the given slice.
func Reverse[T any](input []T) {
	n := len(input)
	for lo := 0; lo < n/2; lo++ {
		hi := n - 1 - lo
		input[lo], input[hi] = input[hi], input[lo]
	}
}
// RemoveStable removes the element at "index" while preserving the order of the remaining elements.
// Every element after "index" is shifted one position to the left, so the cost grows with the number of
// elements after "index".
//
// The type parameter is inferred from the argument. The returned slice shares the backing array of
// "slice", which is modified in place.
//
// It panics with "invalid slice index" if index is not within the bounds of the slice.
func RemoveStable[T any](slice []T, index int) []T {
	if index >= 0 && index < len(slice) {
		copy(slice[index:], slice[index+1:])
		return slice[:len(slice)-1]
	}
	panic("invalid slice index")
}
// Remove deletes the element at "index" in O(1) by moving the last element into its place, so the
// ordering of the slice is NOT preserved. The type parameter is inferred from the argument, and the
// returned slice shares the backing array of "slice".
//
//	intSlice := []int{1, 2, 3, 4}
//	intSlice = Remove(intSlice, 1)
//	result: {1, 4, 3}
//
// It panics with "invalid slice index" if index is not within the bounds of the slice.
func Remove[T any](slice []T, index int) []T {
	if index < 0 || index >= len(slice) {
		panic("invalid slice index")
	}
	last := len(slice) - 1
	slice[index] = slice[last]
	return slice[:last]
}
// ConvertSlice attempts to convert []interface{} to []elementType, where elementType is the dynamic type of
// the first element. If the given slice is empty, or any element is nil or not of exactly the same type as
// the first one, the same slice is returned unchanged.
//
// Exact type equality is required (rather than reflect convertibility) so that values are never silently
// altered, e.g. a float64 truncated to int, or an int turned into a one-rune string.
func ConvertSlice(slice []interface{}) interface{} {
	// a nil first element has no dynamic type to build the result slice from
	if len(slice) == 0 || slice[0] == nil {
		return slice
	}
	elemType := reflect.TypeOf(slice[0])
	converted := reflect.MakeSlice(reflect.SliceOf(elemType), len(slice), len(slice))
	for i, v := range slice {
		// reflect.TypeOf(nil) is nil, so nil elements never match and are handled here too
		if reflect.TypeOf(v) != elemType {
			return slice
		}
		converted.Index(i).Set(reflect.ValueOf(v))
	}
	return converted.Interface()
}
// CommaSeparatedSlice is an alias of []string that can be deserialized from a comma-delimited string
// (e.g. "a, b, c") as well as from a regular JSON array of strings.
type CommaSeparatedSlice []string

// String implements fmt.Stringer, joining the elements with ", ".
func (s CommaSeparatedSlice) String() string {
	return strings.Join(s, ", ")
}

// MarshalText implements encoding.TextMarshaler using the same format as String.
func (s CommaSeparatedSlice) MarshalText() ([]byte, error) {
	return []byte(s.String()), nil
}

// UnmarshalText implements encoding.TextUnmarshaler.
// The text is split on "," and each element is trimmed of surrounding white space; empty elements
// (e.g. from "a,,b") are kept as "". An empty text yields an empty, non-nil slice.
func (s *CommaSeparatedSlice) UnmarshalText(data []byte) error {
	if len(data) == 0 {
		*s = make([]string, 0)
		return nil
	}
	parts := strings.Split(string(data), ",")
	result := make([]string, len(parts))
	for i, part := range parts {
		result[i] = strings.TrimSpace(part)
	}
	*s = result
	return nil
}

// UnmarshalJSON implements json.Unmarshaler.
// It accepts either a JSON array of strings or a JSON string in the comma-separated format of UnmarshalText.
func (s *CommaSeparatedSlice) UnmarshalJSON(data []byte) error {
	// first try regular array
	var slice []string
	if e := json.Unmarshal(data, &slice); e == nil {
		*s = slice
		return nil
	}
	// try comma separated format
	var str string
	if e := json.Unmarshal(data, &str); e != nil {
		return e
	}
	return s.UnmarshalText([]byte(str))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package utils
import (
"encoding/json"
"strconv"
"strings"
"unicode"
)
// UnQuote trims surrounding white space from s and, if the result is wrapped in one matching pair of
// double or single quotes, strips that pair. Otherwise the trimmed string is returned.
func UnQuote(s string) string {
	trimmed := strings.TrimSpace(s)
	n := len(trimmed)
	if n < 2 {
		return trimmed
	}
	first, last := trimmed[0], trimmed[n-1]
	if first != last || (first != '"' && first != '\'') {
		return trimmed
	}
	return trimmed[1 : n-1]
}
// ParseString converts s into the most specific value it represents, trying in order:
//   - float64, via strconv.ParseFloat (which also accepts forms such as "1e3", "NaN" and "Inf")
//   - bool, via strconv.ParseBool (e.g. "true", "F")
//   - map[string]interface{}, if the trimmed string starts with "{" and is a valid JSON object
//   - []interface{}, if the trimmed string starts with "[" and is a valid JSON array
//
// If none applies, s is returned unchanged.
func ParseString(s string) interface{} {
	// numbers are tried before booleans because "1"/"0" would otherwise parse as true/false
	if f, e := strconv.ParseFloat(s, 64); e == nil {
		return f
	}
	if b, e := strconv.ParseBool(s); e == nil {
		return b
	}
	trimmed := strings.TrimSpace(s)
	switch {
	case strings.HasPrefix(trimmed, "{"):
		var obj map[string]interface{}
		if json.Unmarshal([]byte(trimmed), &obj) == nil {
			return obj
		}
	case strings.HasPrefix(trimmed, "["):
		var arr []interface{}
		if json.Unmarshal([]byte(trimmed), &arr) == nil {
			return arr
		}
	}
	return s
}
const dash = rune('-')

// CamelToSnakeCase converts a "camelCase" string to dash-separated lower case ("camel-case").
// A dash is inserted before an upper-case letter only when the preceding character is lower-case, so runs
// of capitals stay together: "HTTPServer" -> "httpserver", "myHTTPServer" -> "my-httpserver".
// The input is processed rune by rune, so non-ASCII letters are handled correctly.
func CamelToSnakeCase(camelCase string) string {
	converted := make([]rune, 0, len(camelCase))
	var prev rune
	for pos, char := range camelCase {
		// compare with the previously seen rune; "pos" is a byte offset and must not be used as a rune index
		if pos > 0 && unicode.IsUpper(char) && unicode.IsLower(prev) {
			converted = append(converted, dash)
		}
		converted = append(converted, unicode.ToLower(char))
		prev = char
	}
	return string(converted)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package utils
import (
"errors"
"time"
)
const (
	ISO8601Seconds      = "2006-01-02T15:04:05Z07:00" // same layout as time.RFC3339
	ISO8601Milliseconds = "2006-01-02T15:04:05.000Z07:00"
)

var (
	// MaxTime is the latest instant representable by time.Time, in UTC.
	// time.Unix internally adds the 62135596800 seconds between year 1 and the Unix epoch to "sec", so
	// time.Unix(1<<63-1, 0) would overflow int64 and wrap to a date far in the past. Subtracting that
	// offset (plus one second, carried by the 999999999 nanoseconds) yields the true maximum.
	MaxTime = time.Unix(1<<63-62135596801, 999999999).UTC()
)
// ParseTimeISO8601 parses v with the ISO8601Seconds layout (the RFC 3339 layout).
// It returns the zero time.Time when v cannot be parsed.
func ParseTimeISO8601(v string) time.Time {
	if parsed, err := time.Parse(ISO8601Seconds, v); err == nil {
		return parsed
	}
	return time.Time{}
}
// ParseTime parses v with the given layout, returning the zero time.Time when parsing fails.
func ParseTime(layout, v string) time.Time {
	if parsed, err := time.Parse(layout, v); err == nil {
		return parsed
	}
	return time.Time{}
}
// ParseDuration parses v with time.ParseDuration, returning 0 when v is not a valid duration.
func ParseDuration(v string) time.Duration {
	if parsed, err := time.ParseDuration(v); err == nil {
		return parsed
	}
	return 0
}
// Duration is a time.Duration that is (un)marshalled as text in time.Duration's string format, e.g. "1h30m0s".
type Duration time.Duration

// MarshalText implements encoding.TextMarshaler
func (d Duration) MarshalText() (text []byte, err error) {
	text = []byte(time.Duration(d).String())
	return text, nil
}

// UnmarshalText implements encoding.TextUnmarshaler
// The receiver is left unchanged when text is not a valid duration.
func (d *Duration) UnmarshalText(text []byte) error {
	if d == nil {
		return errors.New("duration pointer is nil")
	}
	parsed, e := time.ParseDuration(string(text))
	if e == nil {
		*d = Duration(parsed)
	}
	return e
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package validation
import (
"fmt"
"github.com/go-playground/validator/v10"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
)
// CaseInsensitiveOneOf returns a validator function similar to validator's "oneof", but comparing
// case-insensitively (strings.EqualFold). The tag parameter lists the allowed values separated by
// spaces; single quotes group a value containing spaces.
// Supported field kinds are strings and signed/unsigned integers (compared by their base-10 form);
// any other kind panics.
func CaseInsensitiveOneOf() validator.Func {
	return func(fl validator.FieldLevel) bool {
		allowed := parseOneOfParam2(fl.Param())
		field := fl.Field()

		var actual string
		switch field.Kind() {
		case reflect.String:
			actual = field.String()
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			actual = strconv.FormatInt(field.Int(), 10)
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			actual = strconv.FormatUint(field.Uint(), 10)
		default:
			panic(fmt.Sprintf("Bad field type %T", field.Interface()))
		}

		for _, candidate := range allowed {
			if strings.EqualFold(candidate, actual) {
				return true
			}
		}
		return false
	}
}
// splitParamsRegex matches either a single-quoted token (which may contain spaces) or a run of non-space characters.
var splitParamsRegex = regexp.MustCompile(`'[^']*'|\S+`)

// oneofValsCache memoizes parseOneOfParam2 results per parameter string; guarded by oneofValsCacheRWLock.
var oneofValsCache = map[string][]string{}
var oneofValsCacheRWLock = sync.RWMutex{}

// parseOneOfParam2 splits a "oneof"-style parameter into its values, removing single quotes,
// e.g. "red 'light blue'" -> ["red", "light blue"]. Results are cached per parameter string.
func parseOneOfParam2(s string) []string {
	oneofValsCacheRWLock.RLock()
	cached, hit := oneofValsCache[s]
	oneofValsCacheRWLock.RUnlock()
	if hit {
		return cached
	}

	parsed := splitParamsRegex.FindAllString(s, -1)
	for idx, token := range parsed {
		parsed[idx] = strings.ReplaceAll(token, "'", "")
	}

	oneofValsCacheRWLock.Lock()
	oneofValsCache[s] = parsed
	oneofValsCacheRWLock.Unlock()
	return parsed
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package validation
import (
"context"
"encoding"
"fmt"
"github.com/go-playground/validator/v10"
"regexp"
)
// Regex returns a context-aware validator function that passes when the field's textual value matches
// "pattern" (Go RE2 syntax, compiled once here). It panics if pattern does not compile.
func Regex(pattern string) validator.FuncCtx {
	compiled := regexp.MustCompile(pattern)
	return regex(compiled)
}
// RegexPOSIX is like Regex but compiles "pattern" with POSIX ERE syntax and leftmost-longest matching
// (regexp.MustCompilePOSIX). It panics if pattern does not compile.
func RegexPOSIX(pattern string) validator.FuncCtx {
	compiled := regexp.MustCompilePOSIX(pattern)
	return regex(compiled)
}
// regex builds a validator.FuncCtx that matches the field's textual value against "compiled".
// Supported field values, checked in this order: string, *string (nil is validated as ""),
// fmt.Stringer and encoding.TextMarshaler. Any other value fails validation, as does a
// TextMarshaler whose MarshalText returns an error.
func regex(compiled *regexp.Regexp) validator.FuncCtx {
	return func(_ context.Context, fl validator.FieldLevel) bool {
		var str string
		switch v := fl.Field().Interface().(type) {
		case string:
			str = v
		case *string:
			if v != nil {
				str = *v
			}
		case fmt.Stringer:
			str = v.String()
		case encoding.TextMarshaler:
			text, e := v.MarshalText()
			if e != nil {
				// the value has no valid textual form; don't validate an empty string in its place
				return false
			}
			str = string(text)
		default:
			// we don't validate non string, just fail it
			return false
		}
		return compiled.MatchString(str)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package validation
import (
"context"
"encoding"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/go-playground/validator/v10"
)
// TenantAccess returns a context-aware validator function that passes when
// security.HasAccessToTenant(ctx, <field value>) returns true.
// Supported field values, checked in this order: string, *string (nil is checked as ""),
// fmt.Stringer and encoding.TextMarshaler. Any other value fails validation, as does a
// TextMarshaler whose MarshalText returns an error (fail closed).
func TenantAccess() validator.FuncCtx {
	return func(ctx context.Context, fl validator.FieldLevel) bool {
		var tenantID string
		switch v := fl.Field().Interface().(type) {
		case string:
			tenantID = v
		case *string:
			if v != nil {
				tenantID = *v
			}
		case fmt.Stringer:
			tenantID = v.String()
		case encoding.TextMarshaler:
			text, e := v.MarshalText()
			if e != nil {
				// never run the access check against a value we could not turn into a tenant ID
				return false
			}
			tenantID = string(text)
		default:
			// we don't validate non string, just fail it
			return false
		}
		return security.HasAccessToTenant(ctx, tenantID)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package validation
import (
"github.com/go-playground/locales/en"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
)
const (
	// DefaultLocale is the locale requested by DefaultTranslator.
	DefaultLocale = "en"
)

var (
	// universalTranslator is the package-wide ut.UniversalTranslator, created during package initialization.
	universalTranslator = newUniversalTranslator()
)

// DefaultTranslator returns the default ut.Translator of the package: the translator obtained from the
// package's universal translator for DefaultLocale. The "found" flag returned by GetTranslator is ignored.
func DefaultTranslator() ut.Translator {
	trans, _ := universalTranslator.GetTranslator(DefaultLocale)
	return trans
}

// UniversalTranslator returns the globally configured *ut.UniversalTranslator.
// Callers can register more locales on it.
func UniversalTranslator() *ut.UniversalTranslator {
	return universalTranslator
}
// SimpleTranslationRegFunc returns a translation registration function for a simple validation translation
// template. The returned function could be used to register a custom translation override for "tag".
// When rendering, the template receives the field name and the tag's parameter as its first and second
// parameters. If translation fails, the untranslated FieldError message is used instead of an empty string.
func SimpleTranslationRegFunc(tag, template string) func(*validator.Validate, ut.Translator) error {
	return func(validate *validator.Validate, trans ut.Translator) error {
		register := func(t ut.Translator) error {
			// override = true: replace any translation already registered for this tag
			return t.Add(tag, template, true)
		}
		translate := func(t ut.Translator, fe validator.FieldError) string {
			msg, e := t.T(tag, fe.Field(), fe.Param())
			if e != nil {
				return fe.Error()
			}
			return msg
		}
		return validate.RegisterTranslation(tag, trans, register, translate)
	}
}
// newUniversalTranslator creates a ut.UniversalTranslator built with the English locale (en.New()).
func newUniversalTranslator() *ut.UniversalTranslator {
	return ut.New(en.New())
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package utils
import (
"fmt"
"github.com/google/uuid"
"reflect"
)
// primitives is a type constraint matching Go's primitive types (booleans, strings, integers including
// uintptr, floats and complex numbers) as well as any named type whose underlying type is one of them.
type primitives interface {
	~string | ~bool |
		~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr |
		~float32 | ~float64 |
		~complex64 | ~complex128
}
// MustSetIfNotNil takes "src" pointer (e.g. *bool) and, if it is not nil, stores the value it points to into
// the value "dst" points to, converting it to dst's element type. Nothing happens when "src" is nil.
// This function panics if:
//   - "src" is not a pointer (or "dst" is not a pointer)
//   - the value "src" points to is not convertible to dst's element type
//   - "dst" does not point to a settable value (e.g. "dst" is a nil pointer)
func MustSetIfNotNil(dst interface{}, src interface{}) {
	srcV := reflect.ValueOf(src)
	if srcV.IsNil() {
		return
	}
	dstEV := reflect.ValueOf(dst).Elem()
	dstEV.Set(srcV.Elem().Convert(dstEV.Type()))
}

// SetIfNotNil is equivalent of MustSetIfNotNil, this function returns error instead of panic.
// It returns nil when the value was set or "src" is nil.
func SetIfNotNil(dst interface{}, src interface{}) (err error) {
	defer func() {
		switch e := recover().(type) {
		case nil:
			// no panic: keep err nil (without this case, a nil recover() fell into default and produced an error)
		case error:
			err = e
		default:
			err = fmt.Errorf("%v", e)
		}
	}()
	MustSetIfNotNil(dst, src)
	return nil
}
// MustSetIfNotZero takes "src" value (e.g. bool) and, if it is not the zero value of its type, stores it into
// the value "dst" points to, converting it to dst's element type. Nothing happens when "src" is zero.
// This function panics if:
//   - "dst" is not a pointer or does not point to a settable value
//   - "src" is not convertible to dst's element type
func MustSetIfNotZero(dst interface{}, src interface{}) {
	srcV := reflect.ValueOf(src)
	if srcV.IsZero() {
		return
	}
	dstEV := reflect.ValueOf(dst).Elem()
	dstEV.Set(srcV.Convert(dstEV.Type()))
}

// SetIfNotZero is equivalent of MustSetIfNotZero, this function returns error instead of panic.
// It returns nil when the value was set or "src" is zero.
func SetIfNotZero(dst interface{}, src interface{}) (err error) {
	defer func() {
		switch e := recover().(type) {
		case nil:
			// no panic: keep err nil (without this case, a nil recover() fell into default and produced an error)
		case error:
			err = e
		default:
			err = fmt.Errorf("%v", e)
		}
	}()
	MustSetIfNotZero(dst, src)
	return nil
}
// TRUE and FALSE are package-level bool variables holding true and false respectively.
var TRUE, FALSE = true, false
// FromPtr will take a pointer and return the value it points to if it is not nil. Otherwise,
// it will return the zero value of the pointed-to type. ex,
//
//	var s *string
//	FromPtr(s) // results in ""
//	s = ToPtr("hello")
//	FromPtr(s) // results in "hello"
//	var b *bool
//	FromPtr(b) // results in false
//	...
//	// Custom Types with underlying types of primitives
//	type String string
//	var cs *String
//	FromPtr(cs) // results in "" - but typed String
//	cs = ToPtr(String("hello"))
//	FromPtr(cs) // results in "hello" - but typed String
//	// Any other type works as well
//	type Point struct{ X, Y int }
//	FromPtr((*Point)(nil)) // results in Point{}
func FromPtr[T any](t *T) T {
	if t == nil {
		var zero T
		return zero
	}
	return *t
}
// ToPtr returns a pointer to a fresh copy of the given value; every call yields a distinct pointer.
// Example usage:
//
//	var stringPtr *string
//	stringPtr = ToPtr("hello world")
//
//	// or some complex types
//	var funcPtr *[]func(arg *argType)
//	funcPtr = ToPtr([]func(arg *argType){})
func ToPtr[T any](t T) *T {
	ptr := new(T)
	*ptr = t
	return ptr
}
// BoolPtr returns a pointer to a new bool holding v.
// Each call returns a distinct pointer, so writing through it cannot affect other callers.
//
// Deprecated: make use of ToPtr instead
func BoolPtr(v bool) *bool {
	return &v
}
// IntPtr returns a pointer to a copy of v.
//
// Deprecated: make use of ToPtr instead
func IntPtr(v int) *int {
	p := new(int)
	*p = v
	return p
}

// UIntPtr returns a pointer to a copy of v.
//
// Deprecated: make use of ToPtr instead
func UIntPtr(v uint) *uint {
	p := new(uint)
	*p = v
	return p
}

// Float64Ptr returns a pointer to a copy of v.
//
// Deprecated: make use of ToPtr instead
func Float64Ptr(v float64) *float64 {
	p := new(float64)
	*p = v
	return p
}

// StringPtr returns a pointer to a copy of v.
//
// Deprecated: make use of ToPtr instead
func StringPtr(v string) *string {
	p := new(string)
	*p = v
	return p
}
// UuidPtr returns a pointer to a copy of v.
//
// Deprecated: make use of ToPtr instead
func UuidPtr(v uuid.UUID) *uuid.UUID {
	p := new(uuid.UUID)
	*p = v
	return p
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package xsync
import (
"context"
"sync"
"sync/atomic"
"unsafe"
)
// Cond is similar to sync.Cond: Conditional variable implementation that uses channels for notifications.
// This implementation differs from sync.Cond in the following ways:
// - Only supports Broadcast
// - Wait with ctx that can be cancelled
// see sync.Cond for usage
// Ref: https://gist.github.com/zviadm/c234426882bfc8acba88f3503edaaa36#file-cond2-go
//
// A Cond must be created with NewCond: in the zero value the notification channel pointer "n" is nil,
// and both Wait (via notifyChan) and Broadcast dereference it.
type Cond struct {
// L is the locker protecting the condition; Wait unlocks it while waiting and re-locks it before returning.
L sync.Locker
// n points to the current notification channel (a *chan struct{}); it is only read and replaced atomically.
// Broadcast closes the current channel and installs a fresh one.
n unsafe.Pointer
}
// NewCond returns a Cond that uses l as its locker, with a fresh notification channel installed.
func NewCond(l sync.Locker) *Cond {
c := &Cond{L: l}
n := make(chan struct{})
c.n = unsafe.Pointer(&n)
return c
}
// Wait for Broadcast calls. Similar to regular sync.Cond, this unlocks the underlying
// locker first, waits on changes and re-locks it before returning.
// The caller is expected to hold c.L (Wait unlocks it first). It returns nil when woken by a Broadcast, or
// ctx.Err() when woken because ctx is done; if both are ready at once, either may be chosen. In every case
// c.L is locked again on return.
func (c *Cond) Wait(ctx context.Context) (err error) {
// Grab the current notification channel BEFORE unlocking: a Broadcast issued after the unlock swaps in a
// new channel but still closes this one, so the notification cannot be missed.
n := c.notifyChan()
c.L.Unlock()
select {
case <-n:
case <-ctx.Done():
err = ctx.Err()
}
c.L.Lock()
return
}
// Broadcast call notifies everyone that something has changed.
// It atomically installs a new notification channel and closes the previous one, waking every Wait call
// that obtained the previous channel. Because SwapPointer hands each caller a distinct old channel,
// every channel is closed exactly once even with concurrent Broadcast calls.
func (c *Cond) Broadcast() {
n := make(chan struct{})
ptrOld := atomic.SwapPointer(&c.n, unsafe.Pointer(&n))
close(*(*chan struct{})(ptrOld))
}
// notifyChan Returns a channel that can be used to wait for next Broadcast() call.
// The returned channel is closed by the next Broadcast.
func (c *Cond) notifyChan() <-chan struct{} {
ptr := atomic.LoadPointer(&c.n)
return *((*chan struct{})(ptr))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vaultappconfig
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/appconfig"
appconfiginit "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/vault"
)
// ProviderGroupOptions is a functional option that customizes the ProviderGroupOption used by NewProviderGroup.
type ProviderGroupOptions func(opt *ProviderGroupOption)
// ProviderGroupOption holds the settings used by NewProviderGroup.
// NewProviderGroup starts from Precedence = appconfiginit.PrecedenceExternalDefaultContext,
// Backend = DefaultBackend, BackendVersion = DefaultBackendVersion, Path = DefaultConfigPath and
// ProfileSeparator = DefaultProfileSeparator before applying options; VaultClient has no default.
type ProviderGroupOption struct {
// Precedence of the resulting provider group (passed to appconfig.NewProfileBasedProviderGroup).
Precedence int
// Backend is the KV secrets engine mount; the KV v1 engine uses it as the prefix of secret paths.
Backend string
// BackendVersion is the KV secrets engine version; NewKvSecretEngine only accepts 1.
BackendVersion int
// Path is the base secret path; profile-specific paths append ProfileSeparator and the profile name.
Path string
// ProfileSeparator is inserted between Path and the active profile name.
ProfileSeparator string
// VaultClient is the client handed to NewKvSecretEngine.
VaultClient *vault.Client
}
// NewProviderGroup creates a Vault KV engine backed appconfig.ProviderGroup.
// The provider group is responsible for loading application properties from the Vault KV engine at paths:
//
//	<ProviderGroupOption.Backend>/<ProviderGroupOption.Path>[<ProviderGroupOption.ProfileSeparator><any active profile>]
//
// e.g.
//   - "secret/defaultapplication"
//   - "secret/defaultapplication/prod"
//   - "secret/my-service"
//   - "secret/my-service/staging"
//
// An error is returned when NewKvSecretEngine rejects the configured BackendVersion.
func NewProviderGroup(opts ...ProviderGroupOptions) (appconfig.ProviderGroup, error) {
	opt := ProviderGroupOption{
		Precedence:       appconfiginit.PrecedenceExternalDefaultContext,
		Backend:          DefaultBackend,
		BackendVersion:   DefaultBackendVersion,
		Path:             DefaultConfigPath,
		ProfileSeparator: DefaultProfileSeparator,
	}
	for _, apply := range opts {
		apply(&opt)
	}

	engine, e := NewKvSecretEngine(opt.BackendVersion, opt.Backend, opt.VaultClient)
	if e != nil {
		return nil, e
	}

	group := appconfig.NewProfileBasedProviderGroup(opt.Precedence)
	// the default (empty) profile maps to the base path, other profiles get "<path><separator><profile>"
	group.KeyFunc = func(profile string) string {
		if profile != "" {
			return fmt.Sprintf("%s%s%s", opt.Path, opt.ProfileSeparator, profile)
		}
		return opt.Path
	}
	group.CreateFunc = func(name string, order int, _ bootstrap.ApplicationConfig) appconfig.Provider {
		return NewVaultKvProvider(order, name, engine)
	}
	return group, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vaultappconfig
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/vault"
)
// KvSecretEngine abstracts the version-specific details of reading from a Vault KV secret engine.
type KvSecretEngine interface {
// ContextPath returns the full Vault path (including the engine mount) for secretPath.
ContextPath(secretPath string) string
// ListSecrets reads the key/value pairs stored at secretPath.
ListSecrets(ctx context.Context, secretPath string) (results map[string]interface{}, err error)
}
// NewKvSecretEngine returns the KvSecretEngine implementation for the given KV engine version,
// mounted at backend and reading through client. Only version 1 is supported; any other version
// yields an error.
func NewKvSecretEngine(version int, backend string, client *vault.Client) (KvSecretEngine, error) {
	if version != 1 {
		return nil, errors.New("unsupported kv secret engine version")
	}
	return &KvSecretEngineV1{
		client:  client,
		backend: backend,
	}, nil
}
// KvSecretEngineV1 implements KvSecretEngine for the Vault KV v1 secret engine.
type KvSecretEngineV1 struct {
// client used to perform the reads
client *vault.Client
// backend is the engine mount path, prefixed to every secret path
backend string
}
// ContextPath implements KvSecretEngine.
// The KV v1 API expects GET /<backend>/:path, so the context path is simply "<backend>/<secretPath>"
// (as opposed to the v2 API, which expects GET /<backend>/data/:path?version=:version-number).
func (engine *KvSecretEngineV1) ContextPath(secretPath string) string {
	const layout = "%s/%s"
	return fmt.Sprintf(layout, engine.backend, secretPath)
}
// ListSecrets implements KvSecretEngine
/*
Vault key value v1 API has the following response
we return the kv in the data field
{
	"auth": null,
	"data": {
		"foo": "bar",
		"ttl": "1h"
	},
	"lease_duration": 3600,
	"lease_id": "",
	"renewable": false
}
as opposed to the v2 API where the response is
{
	"data": {
		"data": {
			"foo": "bar"
		},
		"metadata": {
			"created_time": "2018-03-22T02:24:06.945319214Z",
			"deletion_time": "",
			"destroyed": false,
			"version": 2
		}
	}
}

String values are passed through utils.ParseString before being returned. Non-string values can occur when
the secret was written as JSON, e.g. numbers, booleans or nested objects. They are returned as-is instead of
panicking on a string type assertion.
A path with no secret (nil response) yields an empty, non-nil map. A read error yields (nil, err).
*/
func (engine *KvSecretEngineV1) ListSecrets(ctx context.Context, secretPath string) (results map[string]interface{}, err error) {
	path := engine.ContextPath(secretPath)
	//nolint:contextcheck // false positive
	secrets, readErr := engine.client.Logical(ctx).Read(path)
	if readErr != nil {
		return nil, readErr
	}

	results = make(map[string]interface{})
	if secrets == nil {
		return results, nil
	}
	for key, val := range secrets.Data {
		if s, ok := val.(string); ok {
			results[key] = utils.ParseString(s)
		} else {
			// keep typed JSON values untouched rather than crashing on them
			results[key] = val
		}
	}
	return results, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vaultappconfig
import (
"github.com/cisco-open/go-lanai/pkg/appconfig"
appconfiginit "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/vault"
"go.uber.org/fx"
)
// Module wires Vault KV backed application configuration into bootstrap.
// At AppConfigPrecedence it provides the bound VaultConfigProperties and two appconfig.ProviderGroup values
// (see appConfigProvidersOut): one for the shared default context and one for the application's own context.
// NOTE(review): the Name "bootstrap endpoint" looks copied from another module; confirm the intended name.
var Module = &bootstrap.Module{
Name: "bootstrap endpoint",
Precedence: bootstrap.AppConfigPrecedence,
PriorityOptions: []fx.Option{
fx.Provide(
bindVaultConfigProperties,
fxNewVaultDefaultContextProviderGroup,
fxNewVaultAppContextProviderGroup,
),
},
}
// groupDI lists the dependencies of the Vault provider group constructors.
// VaultClient is optional; when it is nil (or Vault config is disabled) the constructors provide no group.
type groupDI struct {
fx.In
BootstrapConfig *appconfig.BootstrapConfig
VaultConfigProperties VaultConfigProperties
VaultClient *vault.Client `optional:"true"`
}
// appConfigProvidersOut contributes one appconfig.ProviderGroup to the "application-config" fx value group.
// The constructors leave ProviderGroup nil when Vault config is not applicable.
type appConfigProvidersOut struct {
fx.Out
ProviderGroup appconfig.ProviderGroup `group:"application-config"`
}
// withProperties copies the KV engine location (backend, version, base path and profile separator)
// from props into the provider group options.
func withProperties(props *VaultConfigProperties) ProviderGroupOptions {
	return func(opt *ProviderGroupOption) {
		opt.Backend, opt.BackendVersion = props.Backend, props.BackendVersion
		opt.Path, opt.ProfileSeparator = props.DefaultContext, props.ProfileSeparator
	}
}
// fxNewVaultDefaultContextProviderGroup provides the provider group for the shared default context
// (VaultConfigProperties.DefaultContext) at PrecedenceExternalDefaultContext.
// It provides nothing when Vault config is disabled or no Vault client is available.
func fxNewVaultDefaultContextProviderGroup(di groupDI) (appConfigProvidersOut, error) {
	var out appConfigProvidersOut
	if !di.VaultConfigProperties.Enabled || di.VaultClient == nil {
		return out, nil
	}
	defaultContext := func(opt *ProviderGroupOption) {
		opt.Precedence = appconfiginit.PrecedenceExternalDefaultContext
		opt.VaultClient = di.VaultClient
	}
	group, err := NewProviderGroup(withProperties(&di.VaultConfigProperties), defaultContext)
	out.ProviderGroup = group
	return out, err
}
// fxNewVaultAppContextProviderGroup provides the provider group for the application's own context,
// using the application name as the base path, at PrecedenceExternalAppContext.
// It provides nothing when Vault config is disabled or no Vault client is available.
func fxNewVaultAppContextProviderGroup(di groupDI) (appConfigProvidersOut, error) {
	var out appConfigProvidersOut
	if !di.VaultConfigProperties.Enabled || di.VaultClient == nil {
		return out, nil
	}
	// a missing or non-string application name results in an empty base path
	appName, _ := di.BootstrapConfig.Value(bootstrap.PropertyKeyApplicationName).(string)
	appContext := func(opt *ProviderGroupOption) {
		opt.Precedence = appconfiginit.PrecedenceExternalAppContext
		opt.Path = appName
		opt.VaultClient = di.VaultClient
	}
	group, err := NewProviderGroup(withProperties(&di.VaultConfigProperties), appContext)
	out.ProviderGroup = group
	return out, err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vaultappconfig
import "github.com/cisco-open/go-lanai/pkg/appconfig"
const (
// PropertiesPrefix is the configuration prefix bound into VaultConfigProperties
PropertiesPrefix = "cloud.vault.kv"
// DefaultBackend is the default KV secret engine mount path
DefaultBackend = `secret`
// DefaultBackendVersion is the default KV engine version (the only one NewKvSecretEngine supports)
DefaultBackendVersion = 1
// DefaultConfigPath is the default base path of the shared default context
DefaultConfigPath = "defaultapplication"
// DefaultProfileSeparator is placed between a base path and a profile name
DefaultProfileSeparator = "/"
)
// VaultConfigProperties configures Vault KV backed application config.
// Currently only supports the v1 kv secret engine.
// TODO review property path and prefix
type VaultConfigProperties struct {
Enabled bool `json:"enabled"`
Backend string `json:"backend"`
BackendVersion int `json:"backend-version"`
DefaultContext string `json:"default-context"`
ProfileSeparator string `json:"profile-separator"`
}
// bindVaultConfigProperties binds PropertiesPrefix from the bootstrap config on top of the package defaults
// (enabled, "secret" backend, v1, "defaultapplication", "/"). A binding failure panics.
func bindVaultConfigProperties(bootstrapConfig *appconfig.BootstrapConfig) VaultConfigProperties {
	var props VaultConfigProperties
	props.Enabled = true
	props.Backend = DefaultBackend
	props.BackendVersion = DefaultBackendVersion
	props.DefaultContext = DefaultConfigPath
	props.ProfileSeparator = DefaultProfileSeparator

	if err := bootstrapConfig.Bind(&props, PropertiesPrefix); err != nil {
		panic(err)
	}
	return props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vaultappconfig
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/appconfig"
"github.com/cisco-open/go-lanai/pkg/log"
)
var logger = log.New("Config.Vault")
// KeyValueConfigProvider is an appconfig provider that loads settings from a single Vault KV path.
// Vault kv v1 differs from the v2 API both in how the context path is constructed and how the response is parsed;
// those differences are delegated to secretEngine.
// https://www.vaultproject.io/docs/secrets/kv/kv-v1
type KeyValueConfigProvider struct {
appconfig.ProviderMeta
// secretPath is relative to the secret engine's backend
secretPath string
secretEngine KvSecretEngine
}
// Name identifies the provider as "vault:<full context path>".
func (p *KeyValueConfigProvider) Name() string {
	fullPath := p.secretEngine.ContextPath(p.secretPath)
	return fmt.Sprintf("vault:%s", fullPath)
}
// Load reads all secrets at the provider's path, un-flattens the dotted keys into nested settings and stores
// them in p.Settings. Settings is reset to an empty map first, and Loaded reflects whether this call succeeded.
func (p *KeyValueConfigProvider) Load(ctx context.Context) (loadError error) {
	defer func() {
		p.Loaded = loadError == nil
	}()

	p.Settings = make(map[string]interface{})

	flat, loadError := p.secretEngine.ListSecrets(ctx, p.secretPath)
	if loadError != nil {
		return loadError
	}
	nested, loadError := appconfig.UnFlatten(flat)
	if loadError != nil {
		return loadError
	}
	p.Settings = nested

	logger.WithContext(ctx).Infof("Retrieved %d secrets from vault path: %s", len(flat), p.secretEngine.ContextPath(p.secretPath))
	return nil
}
// NewVaultKvProvider creates a provider with the given precedence that reads secretPath through secretEngine.
func NewVaultKvProvider(precedence int, secretPath string, secretEngine KvSecretEngine) *KeyValueConfigProvider {
	provider := &KeyValueConfigProvider{
		secretPath:   secretPath,
		secretEngine: secretEngine,
	}
	provider.ProviderMeta = appconfig.ProviderMeta{Precedence: precedence}
	return provider
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vault
import "github.com/hashicorp/vault/api"
// ClientAuthentication represents a vault auth method https://www.vaultproject.io/docs/auth
// Login obtains a client token using the given raw Vault API client.
type ClientAuthentication interface {
Login(client *api.Client) (token string, err error)
}
// newClientAuthentication selects the ClientAuthentication for p.Authentication: Kubernetes uses the
// Kubernetes auth method, while Token and any unrecognized method fall back to the static token in p.Token.
func newClientAuthentication(p *ConnectionProperties) ClientAuthentication {
	if p.Authentication == Kubernetes {
		return TokenKubernetesAuthentication(p.Kubernetes)
	}
	return TokenClientAuthentication(p.Token)
}
// TokenClientAuthentication is a static token; Login returns it unchanged and never fails.
type TokenClientAuthentication string

// Login implements ClientAuthentication. The client argument is not used.
func (d TokenClientAuthentication) Login(client *api.Client) (token string, err error) {
	token = string(d)
	return token, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vault
import (
"context"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/api/auth/kubernetes"
)
// KubernetesClient implements ClientAuthentication using Vault's Kubernetes auth method.
type KubernetesClient struct {
config KubernetesConfig
}
// Login implements ClientAuthentication by logging in with the configured role and, when configured,
// a custom service-account token path. It returns the client token from the login response.
func (c *KubernetesClient) Login(client *api.Client) (string, error) {
	// with no options, the auth method defaults to /var/run/secrets/kubernetes.io/serviceaccount/token
	var loginOpts []kubernetes.LoginOption
	if jwtPath := c.config.JWTPath; jwtPath != "" {
		loginOpts = append(loginOpts, kubernetes.WithServiceAccountTokenPath(jwtPath))
	}

	method, err := kubernetes.NewKubernetesAuth(c.config.Role, loginOpts...)
	if err != nil {
		return "", err
	}

	secret, err := client.Auth().Login(context.Background(), method)
	if err != nil {
		return "", err
	}
	return secret.Auth.ClientToken, nil
}
// TokenKubernetesAuthentication creates a Kubernetes auth method client from kubernetesConfig.
func TokenKubernetesAuthentication(kubernetesConfig KubernetesConfig) *KubernetesClient {
	kc := new(KubernetesClient)
	kc.config = kubernetesConfig
	return kc
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vault
import (
"context"
"encoding/json"
"errors"
"github.com/hashicorp/vault/api"
"sync"
)
var (
// errTokenNotRenewable is returned by Client.TokenRenewer when the current token is not renewable
errTokenNotRenewable = errors.New("token is not renewable")
)
// Options customizes a ClientConfig; a non-nil error aborts New / Clone.
type Options func(cfg *ClientConfig) error
// ClientConfig collects everything needed to build a Client.
type ClientConfig struct {
// Config raw config of vault driver
*api.Config
// Properties from bootstrap.BootstrapConfig. Typically set via WithProperties()
Properties ConnectionProperties
// ClientAuth used by the client and internal token refresher to authenticate with Vault server
ClientAuth ClientAuthentication
// Hooks instrumentation points
Hooks []Hook
}
// WithProperties applies connection properties: it records them, derives the ClientAuthentication,
// sets the server address and, for the "https" scheme, configures TLS from p.SSL.
func WithProperties(p ConnectionProperties) Options {
	return func(cfg *ClientConfig) error {
		cfg.Properties = p
		cfg.ClientAuth = newClientAuthentication(&p)
		cfg.Address = p.Address()
		if p.Scheme != "https" {
			return nil
		}
		return cfg.ConfigureTLS(&api.TLSConfig{
			CACert:     p.SSL.CaCert,
			ClientCert: p.SSL.ClientCert,
			ClientKey:  p.SSL.ClientKey,
			Insecure:   p.SSL.Insecure,
		})
	}
}
// Client wraps *api.Client, adding operation hooks and optional automatic token renewal.
type Client struct {
*api.Client
// properties the client was built from; TokenRefresher reads Authentication from it
properties ConnectionProperties
// clientAuth obtains tokens in Authenticate
clientAuth ClientAuthentication
// hooks are appended by AddHooks without holding mu
hooks []Hook
mu sync.Mutex // guards refresher
refresher *TokenRefresher
}
// New creates a Client from the default Vault API config plus the given options, applied in order.
// Without options, the client authenticates with an empty static token.
func New(opts ...Options) (*Client, error) {
	cfg := &ClientConfig{
		Config:     api.DefaultConfig(),
		ClientAuth: TokenClientAuthentication(""),
	}
	for i := range opts {
		if err := opts[i](cfg); err != nil {
			return nil, err
		}
	}
	return newClient(cfg)
}
// newClient builds a Client from cfg and attempts an initial authentication.
// Authentication failure is only logged; the client is still returned.
func newClient(cfg *ClientConfig) (*Client, error) {
	raw, e := api.NewClient(cfg.Config)
	if e != nil {
		return nil, e
	}
	c := &Client{
		Client:     raw,
		properties: cfg.Properties,
		clientAuth: cfg.ClientAuth,
		hooks:      cfg.Hooks,
	}
	if authErr := c.Authenticate(); authErr != nil {
		logger.Warnf("vault client cannot get token %v", authErr)
	}
	return c, nil
}
// Authenticate logs in with the configured ClientAuthentication and, on success, sets the resulting token
// on the underlying API client. On failure the current token is left unchanged.
func (c *Client) Authenticate() error {
	token, err := c.clientAuth.Login(c.Client)
	if err == nil {
		c.Client.SetToken(token)
	}
	return err
}
// AddHooks appends hooks to the client's hook list (invoked, e.g., around Logical operations).
// The context argument is unused. Not synchronized with concurrent operations.
func (c *Client) AddHooks(_ context.Context, hooks ...Hook) {
	for _, h := range hooks {
		c.hooks = append(c.hooks, h)
	}
}

// Logical returns a hook-aware wrapper of api.Logical bound to ctx.
func (c *Client) Logical(ctx context.Context) *Logical {
	l := &Logical{ctx: ctx, client: c}
	l.Logical = c.Client.Logical()
	return l
}

// Sys returns a wrapper of api.Sys bound to ctx.
func (c *Client) Sys(ctx context.Context) *Sys {
	s := &Sys{ctx: ctx, client: c}
	s.Sys = c.Client.Sys()
	return s
}
// AutoRenewToken starts a TokenRefresher that automatically manages and renews the vault token for as long as
// ctx is alive or until Close is called. Calling it again while a refresher exists has no effect.
func (c *Client) AutoRenewToken(ctx context.Context) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.refresher == nil {
		c.refresher = NewTokenRefresher(c)
		c.refresher.Start(ctx)
	}
}
// TokenRenewer returns an api.Renewer (built with NewLifetimeWatcher) for manual token management.
// Use AutoRenewToken for automatic renewal.
//
// The current token is looked up first:
//   - errTokenNotRenewable is returned when the token is not renewable;
//   - an error is returned when the lookup yields no secret at all, instead of dereferencing a nil secret;
//   - otherwise the watcher is configured with the token's remaining "ttl" as the renewal increment
//     (0 when the ttl is absent or not a json.Number).
func (c *Client) TokenRenewer() (*api.Renewer, error) {
	secret, err := c.Client.Auth().Token().LookupSelf()
	if err != nil {
		return nil, err
	}
	// LookupSelf can return (nil, nil), e.g. for an empty response body
	if secret == nil {
		return nil, errors.New("token lookup returned no data")
	}

	// a missing key or a non-bool value both count as "not renewable"
	renewable, _ := secret.Data["renewable"].(bool)
	if !renewable {
		return nil, errTokenNotRenewable
	}

	var increment int64
	if ttl, ok := secret.Data["ttl"].(json.Number); ok {
		increment, _ = ttl.Int64()
	}

	return c.Client.NewLifetimeWatcher(&api.LifetimeWatcherInput{
		Secret: &api.Secret{
			Auth: &api.SecretAuth{
				ClientToken: c.Client.Token(),
				Renewable:   renewable,
			},
		},
		Increment: int(increment),
	})
}
// Close stops the automatic token refresher, if one was started. It always returns nil.
func (c *Client) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if r := c.refresher; r != nil {
		r.Stop()
		c.refresher = nil
	}
	return nil
}
// Clone makes a new Client from a copy of the current config, properties, authentication and hooks,
// with the given options applied on top. The clone does not inherit a running token refresher.
func (c *Client) Clone(opts ...Options) (*Client, error) {
	hooks := make([]Hook, len(c.hooks))
	copy(hooks, c.hooks)

	cfg := &ClientConfig{
		Config:     c.Client.CloneConfig(),
		Properties: c.properties,
		ClientAuth: c.clientAuth,
		Hooks:      hooks,
	}
	for i := range opts {
		if err := opts[i](cfg); err != nil {
			return nil, err
		}
	}
	return newClient(cfg)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vault
import "strings"
const (
	Token      = AuthMethod("token")
	Kubernetes = AuthMethod("kubernetes")
)

// refreshable lists the auth methods whose credentials can be used to obtain a brand-new token
// once the current one can no longer be renewed.
var refreshable = map[AuthMethod]struct{}{
	Kubernetes: {},
}

// AuthMethod identifies a Vault authentication method. Known values are the lower-case constants above.
type AuthMethod string

// UnmarshalText implements encoding.TextUnmarshaler.
// The value is normalized to lower case so that configuration such as "kubernetes", "Kubernetes" or
// "KUBERNETES" matches the lower-case constants. Upper-casing would never match Token or Kubernetes and
// would silently fall back to static token authentication.
func (a *AuthMethod) UnmarshalText(data []byte) error {
	*a = AuthMethod(strings.ToLower(string(data)))
	return nil
}

// isRefreshable reports whether a is listed in refreshable.
func (a AuthMethod) isRefreshable() bool {
	_, ok := refreshable[a]
	return ok
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vaulthealth
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
"github.com/cisco-open/go-lanai/pkg/vault"
"go.uber.org/fx"
)
// HealthRegDI lists the optional dependencies of Register; registration is skipped if either is missing.
type HealthRegDI struct {
fx.In
HealthRegistrar health.Registrar `optional:"true"`
VaultClient *vault.Client `optional:"true"`
}
// Register adds the vault HealthIndicator to the health registrar when both a registrar and a
// Vault client are available; otherwise it does nothing.
func Register(di HealthRegDI) error {
	if di.HealthRegistrar != nil && di.VaultClient != nil {
		return di.HealthRegistrar.Register(New(di.VaultClient))
	}
	return nil
}
// New creates a HealthIndicator that checks the given Vault client.
func New(client *vault.Client) *HealthIndicator {
	indicator := new(HealthIndicator)
	indicator.Client = client
	return indicator
}
// HealthIndicator reports Vault health based on the /v1/sys/health endpoint.
type HealthIndicator struct {
Client *vault.Client
}
// Name returns the indicator's name, "vault".
func (i *HealthIndicator) Name() string {
	const name = "vault"
	return name
}

// Health calls Vault's sys/health endpoint: DOWN if the call fails, UP otherwise.
// The options are not used and no details are attached.
func (i *HealthIndicator) Health(c context.Context, options health.Options) health.Health {
	if _, e := i.Client.Sys(c).Health(); e != nil {
		return health.NewDetailedHealth(health.StatusDown, "vault /v1/sys/health failed", nil)
	}
	return health.NewDetailedHealth(health.StatusUp, "vault /v1/sys/health succeeded", nil)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vault
import (
"context"
"embed"
"github.com/cisco-open/go-lanai/pkg/appconfig"
appconfigInit "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/vault"
vaultappconfig "github.com/cisco-open/go-lanai/pkg/vault/appconfig"
vaulthealth "github.com/cisco-open/go-lanai/pkg/vault/health"
vaulttracing "github.com/cisco-open/go-lanai/pkg/vault/tracing"
"go.uber.org/fx"
)
// defaultConfigFS holds defaults-vault.yml, registered as embedded default configuration by Module.
//go:embed defaults-vault.yml
var defaultConfigFS embed.FS
// Module is the bootstrap module for the Vault integration.
// Priority options bind vault.ConnectionProperties and provide the default *vault.Client.
// Regular options register the embedded defaults, the vault health indicator and the client lifecycle hooks.
// It also pulls in the vault appconfig and vault tracing sub-modules.
var Module = &bootstrap.Module{
Name: "vault",
Precedence: bootstrap.VaultPrecedence,
PriorityOptions: []fx.Option{
fx.Provide(BindConnectionProperties, ProvideDefaultClient),
},
Options: []fx.Option{
appconfigInit.FxEmbeddedDefaults(defaultConfigFS),
fx.Invoke(vaulthealth.Register, manageClientLifecycle),
},
Modules: []*bootstrap.Module{
vaultappconfig.Module,
vaulttracing.Module,
},
}
// Use func, does nothing. Allow service to include this module in main()
func Use() {
bootstrap.Register(Module)
}
// BindConnectionProperties binds vault.PropertyPrefix from the bootstrap config on top of local defaults
// (http://localhost:8200 with a placeholder static token). On a binding error, the partially bound
// properties are returned together with the error.
func BindConnectionProperties(bootstrapConfig *appconfig.BootstrapConfig) (vault.ConnectionProperties, error) {
	props := vault.ConnectionProperties{
		Host:           "localhost",
		Port:           8200,
		Scheme:         "http",
		Authentication: vault.Token,
		Token:          "replace_with_token_value",
	}
	err := bootstrapConfig.Bind(&props, vault.PropertyPrefix)
	return props, err
}
// clientDI lists the dependencies of ProvideDefaultClient.
// Customizers are collected from the "vault" fx value group and applied after the connection properties.
type clientDI struct {
fx.In
Props vault.ConnectionProperties
Customizers []vault.Options `group:"vault"`
}
// ProvideDefaultClient creates the application's *vault.Client from the bound connection properties followed by
// any registered customizers. Client creation errors panic.
func ProvideDefaultClient(di clientDI) *vault.Client {
	opts := make([]vault.Options, 0, len(di.Customizers)+1)
	opts = append(opts, vault.WithProperties(di.Props))
	opts = append(opts, di.Customizers...)

	client, err := vault.New(opts...)
	if err != nil {
		panic(err)
	}
	return client
}
// lcDI lists the dependencies of manageClientLifecycle; without a Vault client nothing is registered.
type lcDI struct {
fx.In
AppCtx *bootstrap.ApplicationContext
Lifecycle fx.Lifecycle
VaultClient *vault.Client `optional:"true"`
}
// manageClientLifecycle starts automatic token renewal when the application starts (bound to the application
// context, not the start-hook context) and closes the client when the application stops.
func manageClientLifecycle(di lcDI) {
	client := di.VaultClient
	if client == nil {
		return
	}
	onStart := func(_ context.Context) {
		//nolint:contextcheck // Non-inherited new context - intentional. Start hook context expires when startup finishes
		client.AutoRenewToken(di.AppCtx)
	}
	onStop := func(_ context.Context) error {
		return client.Close()
	}
	di.Lifecycle.Append(fx.StartHook(onStart))
	di.Lifecycle.Append(fx.StopHook(onStop))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vault
import (
"context"
"fmt"
"github.com/hashicorp/vault/api"
"io"
"net/http"
"strings"
)
// Logical wraps api.Logical so that operations run the owning Client's hooks and use the bound context.
type Logical struct {
*api.Logical
// ctx passed to hooks and used for requests sent by write
ctx context.Context
client *Client
}
// WithContext returns a copy of l bound to ctx. A nil ctx panics.
func (l *Logical) WithContext(ctx context.Context) *Logical {
	if ctx == nil {
		panic("nil context is not allowed")
	}
	clone := *l
	clone.ctx = ctx
	return &clone
}
// Read overrides api.Logical.Read so that the client's hooks run around the call.
func (l *Logical) Read(path string) (ret *api.Secret, err error) {
	opCtx := l.beforeOp(l.ctx, "Read", path)
	defer func() {
		l.afterOp(opCtx, err)
	}()
	return l.Logical.Read(path)
}

// ReadWithData overrides api.Logical.ReadWithData so that the client's hooks run around the call.
// Note: data is sent as HTTP parameters. Hooks see this operation as "Read".
func (l *Logical) ReadWithData(path string, data map[string][]string) (ret *api.Secret, err error) {
	opCtx := l.beforeOp(l.ctx, "Read", path)
	defer func() {
		l.afterOp(opCtx, err)
	}()
	return l.Logical.ReadWithData(path, data)
}

// Write overrides api.Logical.Write with hooks. Unlike the original, data may be any JSON-serializable value
// instead of a map.
// Note: Write sends a PUT request.
func (l *Logical) Write(path string, data interface{}) (ret *api.Secret, err error) {
	opCtx := l.beforeOp(l.ctx, "Write", path)
	defer func() {
		l.afterOp(opCtx, err)
	}()
	return l.writeWithMethod(http.MethodPut, path, data) //nolint:contextcheck
}

// Post extends api.Logical. It is similar to Write, but sends a POST request.
func (l *Logical) Post(path string, data interface{}) (ret *api.Secret, err error) {
	opCtx := l.beforeOp(l.ctx, "Post", path)
	defer func() {
		l.afterOp(opCtx, err)
	}()
	return l.writeWithMethod(http.MethodPost, path, data) //nolint:contextcheck
}

// WriteWithMethod extends api.Logical to send a POST or PUT request. The method is upper-cased before
// validation; hooks see it as given.
func (l *Logical) WriteWithMethod(method, path string, data interface{}) (ret *api.Secret, err error) {
	opCtx := l.beforeOp(l.ctx, method, path)
	defer func() {
		l.afterOp(opCtx, err)
	}()
	return l.writeWithMethod(strings.ToUpper(method), path, data) //nolint:contextcheck
}
// beforeOp passes ctx through every registered hook, in registration order, announcing the command
// "<name> <path>", and returns the resulting context.
func (l *Logical) beforeOp(ctx context.Context, name, path string) context.Context {
	cmd := fmt.Sprintf("%s %s", name, path)
	hooks := l.client.hooks
	hooked := ctx
	for i := range hooks {
		hooked = hooks[i].BeforeOperation(hooked, cmd)
	}
	return hooked
}

// afterOp notifies every registered hook, in registration order, that the operation finished with err.
func (l *Logical) afterOp(ctx context.Context, err error) {
	hooks := l.client.hooks
	for i := range hooks {
		hooks[i].AfterOperation(ctx, err)
	}
}
// writeWithMethod sends data as a JSON body to /v1/<path> using POST or PUT; any other method is rejected.
//
//nolint:contextcheck // context is bound to the struct
func (l *Logical) writeWithMethod(method, path string, data interface{}) (*api.Secret, error) {
	if method != http.MethodPost && method != http.MethodPut {
		return nil, fmt.Errorf("invalid HTTP method, only POST and PUT are accepted")
	}
	req := l.client.NewRequest(method, "/v1/"+path)
	if err := req.SetJSONBody(data); err != nil {
		return nil, err
	}
	return l.write(req)
}
// write sends request using a cancellable child of l.ctx and decodes the response into an *api.Secret.
// The logic appears to mirror the (unexported) write of hashicorp/vault api's Logical.
// NOTE(review): confirm against the vendored version before changing it.
//   - 404 whose body parses to a secret with warnings or data: that secret is returned together with err.
//   - 404 whose body parse yields io.EOF: (nil, nil).
//   - 404 whose body fails to parse otherwise: (nil, err). This is the request error, not parseErr.
//   - otherwise a non-nil err is returned as-is, or the body is parsed into the result.
//
// NOTE(review): in the 404 fall-through path, resp.Body has already been consumed by the first ParseSecret call.
//nolint:contextcheck // context is bound to the struct
func (l *Logical) write(request *api.Request) (*api.Secret, error) {
ctx, cancelFunc := context.WithCancel(l.ctx)
defer cancelFunc()
//nolint:staticcheck // Deprecated API. TODO should fix
resp, err := l.client.RawRequestWithContext(ctx, request)
if resp != nil {
defer resp.Body.Close()
}
if resp != nil && resp.StatusCode == 404 {
secret, parseErr := api.ParseSecret(resp.Body)
switch parseErr {
case nil:
case io.EOF:
return nil, nil
default:
return nil, err
}
if secret != nil && (len(secret.Warnings) > 0 || len(secret.Data) > 0) {
return secret, err
}
}
if err != nil {
return nil, err
}
return api.ParseSecret(resp.Body)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vault
import "fmt"
const (
// PropertyPrefix is the configuration prefix bound into ConnectionProperties
PropertyPrefix = "cloud.vault"
)
// ConnectionProperties describes how to reach and authenticate with the Vault server.
type ConnectionProperties struct {
Host string `json:"host"`
Port int `json:"port"`
// Scheme is "http" or "https"; WithProperties only configures TLS for "https"
Scheme string `json:"scheme"`
// Authentication selects the auth method (Token or Kubernetes)
Authentication AuthMethod `json:"authentication"`
// SSL settings, used when Scheme is "https"
SSL SSLProperties `json:"ssl"`
// Kubernetes auth settings, used when Authentication is Kubernetes
Kubernetes KubernetesConfig `json:"kubernetes"`
// Token is the static token used by token authentication
Token string `json:"token"`
}
// Address returns the server URL as "<scheme>://<host>:<port>".
func (p ConnectionProperties) Address() string {
	hostPort := fmt.Sprintf("%s:%d", p.Host, p.Port)
	return p.Scheme + "://" + hostPort
}
// SSLProperties are mapped onto api.TLSConfig by WithProperties when the scheme is "https".
type SSLProperties struct {
CaCert string `json:"ca-cert"`
ClientCert string `json:"client-cert"`
ClientKey string `json:"client-key"`
// Insecure is passed through to api.TLSConfig.Insecure
Insecure bool `json:"insecure"`
}
// KubernetesConfig configures the Kubernetes auth method used by KubernetesClient.
type KubernetesConfig struct {
// JWTPath overrides the service account token path; when empty the auth method's default path is used
JWTPath string `json:"jwt-path"`
// Role is the Vault role to log in as
Role string `json:"role"`
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vault
import (
"context"
"errors"
"github.com/hashicorp/vault/api"
"sync"
"time"
)
// TokenRefresher performs renewal & refreshment of a client's token.
// Renewal extends the current token's TTL.
// Refresh (re-authentication) happens when the token can no longer be renewed (e.g. max TTL is reached).
type TokenRefresher struct {
client *Client
// renewer is only accessed by the monitorRenew goroutine
renewer *api.Renewer
// cancelFunc stops monitorRenew; non-nil while running. Guarded by cancelLock
cancelFunc context.CancelFunc
cancelLock sync.Mutex
}
// renewerDescription prefixes the refresher's log messages.
const renewerDescription = "vault client token"
// NewTokenRefresher creates a stopped TokenRefresher for client; call Start to begin renewing.
func NewTokenRefresher(client *Client) *TokenRefresher {
	r := new(TokenRefresher)
	r.client = client
	return r
}
// Start begins token renewal & refreshing in a background goroutine bound to a cancellable child of ctx.
// Calling Start while already running has no effect.
func (r *TokenRefresher) Start(ctx context.Context) {
	r.cancelLock.Lock()
	defer r.cancelLock.Unlock()
	if r.cancelFunc == nil {
		runCtx, cancel := context.WithCancel(ctx)
		r.cancelFunc = cancel
		// background process that renews the token and logs the renewal events
		go r.monitorRenew(runCtx)
	}
}
// Stop cancels the token renewal/refreshing goroutine started by Start. Calling it when not running is a no-op.
func (r *TokenRefresher) Stop() {
	r.cancelLock.Lock()
	defer r.cancelLock.Unlock()
	if cancel := r.cancelFunc; cancel != nil {
		cancel()
		r.cancelFunc = nil
	}
}
// isRefreshable reports whether the client's configured auth method can obtain a fresh token.
func (r *TokenRefresher) isRefreshable() bool {
	method := r.client.properties.Authentication
	return method.isRefreshable()
}
// monitorRenew runs the token renewal loop until ctx is cancelled.
//
// Each iteration makes sure a renewer exists (see awaitRenewer), then waits for one of:
//   - a renewal event, which is logged;
//   - the renewer finishing (e.g. renewal failed or max TTL reached): the renewer is discarded and the client
//     re-authenticates, so the next iteration builds a renewer for the resulting token;
//   - ctx cancellation: the renewer is stopped and the loop returns.
//
// Every wait in this loop observes ctx, including the retry delay while no renewer can be created
// (e.g. a static, non-renewable token). Stop() therefore always terminates the goroutine promptly.
func (r *TokenRefresher) monitorRenew(ctx context.Context) {
	for {
		if r.renewer == nil {
			if !r.awaitRenewer(ctx) {
				return
			}
			// Starts a blocking process to periodically renew the token.
			go r.renewer.Start()
		}
		select {
		case renewal := <-r.renewer.RenewCh():
			logger.WithContext(ctx).Debugf("%s successfully renewed at %v", renewerDescription, renewal.RenewedAt)
		case err := <-r.renewer.DoneCh():
			r.renewer = nil
			switch {
			case !r.isRefreshable():
				// When authentication is token, and if the token expires, we can't really do anything on the client side.
				// Keep looping in the hope that the token is recreated & renewal can resume.
				logger.WithContext(ctx).Warnf("%s renewer stopped for non-refreshable authentication: %v", renewerDescription, err)
			case err != nil:
				logger.WithContext(ctx).Infof("%s renewer stopped with error, will re-authenticate & restart: %v", renewerDescription, err)
			default:
				logger.WithContext(ctx).Debugf("%s renewer stopped, will re-authenticate & restart", renewerDescription)
			}
			// re-authentication is attempted for every auth method; on failure the next iteration keeps
			// retrying via awaitRenewer
			if authErr := r.client.Authenticate(); authErr != nil {
				logger.WithContext(ctx).Errorf("Could not get a new token: %v", authErr)
			}
		case <-ctx.Done():
			r.renewer.Stop()
			r.renewer = nil
			return
		}
	}
}

// awaitRenewer blocks until Client.TokenRenewer succeeds, stores the result in r.renewer and returns true.
// Failures are retried every 5 minutes. It returns false, leaving r.renewer nil, as soon as ctx is cancelled.
func (r *TokenRefresher) awaitRenewer(ctx context.Context) bool {
	const retryInterval = 5 * time.Minute
	for {
		renewer, err := r.client.TokenRenewer()
		if err == nil {
			r.renewer = renewer
			return true
		}
		if !errors.Is(err, errTokenNotRenewable) {
			// Don't want to spam this message if the user is using a static token (where renewals aren't needed)
			logger.WithContext(ctx).Debugf("%s unable to create token renewer, %v", renewerDescription, err)
		}
		// If the token expired or the lease was revoked, wait and check again whether the token became valid
		// (e.g. it was recreated in vault). Unlike time.Sleep, this wait ends early on cancellation.
		timer := time.NewTimer(retryInterval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return false
		case <-timer.C:
		}
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vaulttracing
import (
"context"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
// opName prefixes the span operation name of every traced vault command.
const opName = "vault"
// Hook is a vault client hook that wraps each operation in a tracing span.
type Hook struct {
tracer opentracing.Tracer
}
func NewHook(tracer opentracing.Tracer) *Hook {
return &Hook{
tracer: tracer,
}
}
// BeforeOperation is called before a vault command runs. It builds an RPC-client span named
// "vault <cmd>", tagged with the command, and returns the context produced by DescendantOrNoSpan.
// NOTE(review): going by the name, a span is presumably only created when ctx already carries a parent span — confirm in pkg/tracing.
func (h *Hook) BeforeOperation(ctx context.Context, cmd string) context.Context {
	op := tracing.WithTracer(h.tracer).
		WithOpName(opName + " " + cmd).
		WithOptions(
			tracing.SpanKind(ext.SpanKindRPCClientEnum),
			tracing.SpanTag("cmd", cmd),
		)
	return op.DescendantOrNoSpan(ctx)
}
// AfterOperation is called after a vault command completes. When err is non-nil it is attached as the
// "err" tag, then Finish is called with ctx. Presumably this finishes the span that BeforeOperation
// stored in ctx — confirm in pkg/tracing.
func (h *Hook) AfterOperation(ctx context.Context, err error) {
	op := tracing.WithTracer(h.tracer)
	if err != nil {
		// Keep the operator returned by the builder rather than relying on WithOptions mutating op in place.
		op = op.WithOptions(tracing.SpanTag("err", err))
	}
	op.Finish(ctx)
}
package vaulttracing
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/vault"
"github.com/opentracing/opentracing-go"
"go.uber.org/fx"
)
// Module registers vault tracing instrumentation. It is ordered with bootstrap.TracingPrecedence
// and invokes initialize as a priority option.
var Module = &bootstrap.Module{
	Name:       "vault-tracing",
	Precedence: bootstrap.TracingPrecedence,
	PriorityOptions: []fx.Option{
		fx.Invoke(initialize),
	},
}

// tracerDI collects the dependencies of initialize. Both the tracer and the vault client are optional;
// when either is absent, no instrumentation is installed.
type tracerDI struct {
	fx.In
	AppContext  *bootstrap.ApplicationContext
	Tracer      opentracing.Tracer `optional:"true"`
	VaultClient *vault.Client      `optional:"true"`
	// we could include security configurations, customizations here
}
// initialize installs a tracing Hook on the vault client when both a tracer and a vault client are available.
func initialize(di tracerDI) {
	// vault instrumentation requires both optional dependencies
	if di.Tracer == nil || di.VaultClient == nil {
		return
	}
	di.VaultClient.AddHooks(di.AppContext, NewHook(di.Tracer))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package vault
import (
"context"
"encoding/base64"
"fmt"
"github.com/hashicorp/vault/api"
"net/url"
)
// Vault transit engine API path templates. The %s placeholder receives the URL-path-escaped key ID.
const (
	pathTmplCreateKey = `transit/keys/%s`
	pathTmplEncrypt   = `transit/encrypt/%s`
	pathTmplDecrypt   = `transit/decrypt/%s`
)

const (
	// defaultTransitKeyType is used when no KeyOptions sets a non-empty KeyType.
	defaultTransitKeyType = "aes256-gcm96"
	// respKeyCipherText is the field in the encrypt response data that holds the ciphertext.
	respKeyCipherText = "ciphertext"
	// respKeyPlainText is the field in the decrypt response data that holds the base64-encoded plaintext.
	respKeyPlainText = "plaintext"
)
// TransitEngine exposes the subset of vault's transit secrets engine used by this package.
type TransitEngine interface {
	// PrepareKey creates (POSTs) the transit key named kid.
	PrepareKey(ctx context.Context, kid string) error
	// Encrypt encrypts plaintext with key kid and returns vault's ciphertext string as bytes.
	Encrypt(ctx context.Context, kid string, plaintext []byte) ([]byte, error)
	// Decrypt decrypts a ciphertext previously produced by Encrypt with key kid.
	Decrypt(ctx context.Context, kid string, cipher []byte) ([]byte, error)
}
// KeyOptions is a functional option that customizes the KeyOption used by NewTransitEngine.
type KeyOptions func(opt *KeyOption)

// KeyOption holds the key creation parameters sent by PrepareKey.
type KeyOption struct {
	// KeyType is the vault transit key type; an empty value falls back to defaultTransitKeyType.
	KeyType              string
	Exportable           bool
	AllowPlaintextBackup bool
}
// transit implements TransitEngine on top of the vault Client, with key parameters fixed at construction.
type transit struct {
	c                *Client
	keyType          string
	exportable       bool
	allowPlaintextBk bool
}
// NewTransitEngine creates a TransitEngine backed by client. Options are applied in order to a KeyOption
// that starts with KeyType set to defaultTransitKeyType. If an option clears KeyType, the default is restored.
func NewTransitEngine(client *Client, opts ...KeyOptions) TransitEngine {
	var opt KeyOption
	opt.KeyType = defaultTransitKeyType
	for i := range opts {
		opts[i](&opt)
	}

	keyType := opt.KeyType
	if len(keyType) == 0 {
		keyType = defaultTransitKeyType
	}

	return &transit{
		c:                client,
		keyType:          keyType,
		exportable:       opt.Exportable,
		allowPlaintextBk: opt.AllowPlaintextBackup,
	}
}
// PrepareKey POSTs a create-key request for kid to the transit engine, using the engine's configured
// key type, exportability and plaintext-backup settings. Any error from the vault client is returned.
func (t *transit) PrepareKey(ctx context.Context, kid string) error {
	body := transitCreateKey{
		Type:                 t.keyType,
		Exportable:           t.exportable,
		AllowPlaintextBackup: t.allowPlaintextBk,
	}
	//nolint:contextcheck
	_, err := t.c.Logical(ctx).Post(fmt.Sprintf(pathTmplCreateKey, url.PathEscape(kid)), &body)
	return err
}
// Encrypt base64-encodes plaintext, sends it to the transit encrypt endpoint for key kid, and returns
// the "ciphertext" string from the response as bytes. On any error the returned slice is nil.
func (t *transit) Encrypt(ctx context.Context, kid string, plaintext []byte) ([]byte, error) {
	path := fmt.Sprintf(pathTmplEncrypt, url.PathEscape(kid))
	req := transitEncrypt{
		PlaintextB64: base64.StdEncoding.EncodeToString(plaintext),
	}
	s, e := t.c.Logical(ctx).Post(path, &req) //nolint:contextcheck
	if e != nil {
		return nil, e
	}
	ciphertext, e := t.extractString(s, respKeyCipherText)
	if e != nil {
		// don't hand back a partially meaningful (empty, non-nil) result alongside an error
		return nil, e
	}
	return []byte(ciphertext), nil
}
// Decrypt sends cipher to the transit decrypt endpoint for key kid. It reads the base64 "plaintext"
// field from the response and returns the decoded bytes.
func (t *transit) Decrypt(ctx context.Context, kid string, cipher []byte) ([]byte, error) {
	body := transitDecrypt{Ciphertext: string(cipher)}
	secret, err := t.c.Logical(ctx).Post(fmt.Sprintf(pathTmplDecrypt, url.PathEscape(kid)), &body) //nolint:contextcheck
	if err != nil {
		return nil, err
	}
	encoded, err := t.extractString(secret, respKeyPlainText)
	if err != nil {
		return nil, err
	}
	return base64.StdEncoding.DecodeString(encoded)
}
// post POSTs reqData to path and returns the resulting secret. A client error is returned as-is.
// A response without data is reported as an error. A nil secret counts as "no data" instead of being
// dereferenced; the vault API client may return one, e.g. for empty response bodies.
func (t *transit) post(ctx context.Context, path string, reqData interface{}) (ret *api.Secret, err error) {
	ret, err = t.c.Logical(ctx).Post(path, reqData) //nolint:contextcheck
	switch {
	case err != nil:
		return ret, err
	case ret == nil || ret.Data == nil:
		return nil, fmt.Errorf("missing data in vault response")
	}
	return ret, nil
}
// extractString returns the string value stored under key in the data of secret s.
// It returns an error in three cases:
//   - the secret or its data is missing. A nil secret is possible from the vault API client, so it is checked before dereferencing.
//   - the key is absent.
//   - the value is not a string.
func (t *transit) extractString(s *api.Secret, key string) (string, error) {
	if s == nil || s.Data == nil {
		return "", fmt.Errorf("missing data in vault response")
	}
	v, ok := s.Data[key]
	if !ok {
		return "", fmt.Errorf("missing %s in vault response data", key)
	}
	text, ok := v.(string)
	if !ok {
		return "", fmt.Errorf("invalid type of %s in vault response data, expected string but got %T", key, v)
	}
	return text, nil
}
/*************************
Requests
*************************/
// transitCreateKey is a subset of all supported request parameters of `POST transit/keys/:name`
// see https://www.vaultproject.io/api/secret/transit#create-key
type transitCreateKey struct {
	Type                 string `json:"type"`
	Exportable           bool   `json:"exportable,omitempty"`
	AllowPlaintextBackup bool   `json:"allow_plaintext_backup,omitempty"`
}

// transitEncrypt is a subset of all supported request parameters of `POST /transit/encrypt/:name`
// see https://www.vaultproject.io/api/secret/transit#encrypt-data
// Fields suffixed B64 carry base64-encoded values; only PlaintextB64 is set by Encrypt.
type transitEncrypt struct {
	PlaintextB64 string `json:"plaintext"`
	Type         string `json:"type,omitempty"`
	ContextB64   string `json:"context,omitempty"`
	KeyVersion   int    `json:"key_version,omitempty"`
	NonceB64     string `json:"nonce,omitempty"`
}

// transitDecrypt is a subset of all supported request parameters of `POST /transit/decrypt/:name`
// see https://www.vaultproject.io/api/secret/transit#decrypt-data
// Only Ciphertext is set by Decrypt.
type transitDecrypt struct {
	Ciphertext string `json:"ciphertext"`
	ContextB64 string `json:"context,omitempty"`
	NonceB64   string `json:"nonce,omitempty"`
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package assets
import (
"github.com/cisco-open/go-lanai/pkg/web"
"net/http"
)
// assetsMapping implements web.StaticMapping: files under root are exposed at path,
// with optional per-path aliases.
type assetsMapping struct {
	path    string
	root    string
	aliases map[string]string
}

// New creates a web.StaticMapping for relativePath backed by the directory assetsRootPath.
// The mapping starts with no aliases and reports GET as its method.
func New(relativePath string, assetsRootPath string) web.StaticMapping {
	a := &assetsMapping{aliases: make(map[string]string)}
	a.path = relativePath
	a.root = assetsRootPath
	return a
}

/*****************************
	StaticMapping Interface
******************************/

// Name returns the mapping path, which doubles as its name.
func (a *assetsMapping) Name() string {
	return a.path
}

// Path returns the route path the assets are mapped to.
func (a *assetsMapping) Path() string {
	return a.path
}

// Method always returns http.MethodGet.
func (a *assetsMapping) Method() string {
	return http.MethodGet
}

// StaticRoot returns the root directory of the assets.
func (a *assetsMapping) StaticRoot() string {
	return a.root
}

// Aliases returns the internal alias map (not a copy).
func (a *assetsMapping) Aliases() map[string]string {
	return a.aliases
}

// AddAlias maps path to filePath, overwriting any existing alias for path, and returns the mapping for chaining.
func (a *assetsMapping) AddAlias(path, filePath string) web.StaticMapping {
	a.aliases[path] = filePath
	return a
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"context"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
"net/http"
"regexp"
)
// Validation reference: https://godoc.org/github.com/go-playground/validator#hdr-Baked_In_Validators_and_Tags
var (
	// pathParamPattern matches a "/:name" path-parameter segment (the slash, the colon and the name up to the
	// next slash). MustCompile is used so that a malformed pattern fails loudly at init instead of leaving a
	// nil *regexp.Regexp behind from a silently discarded error.
	pathParamPattern = regexp.MustCompile(`\/:[^\/]*`)
)
/*********************************
Customization
*********************************/
// Customizer is invoked by Registrar at the beginning of initialization.
// Customizers can register anything except additional customizers.
// If a customizer retains the given context in any way, it should also implement PostInitCustomizer to release it.
type Customizer interface {
	Customize(ctx context.Context, r *Registrar) error
}

// PostInitCustomizer is invoked by Registrar after initialization. Registering anything from
// PostInitCustomizer.PostInit causes an error or has no effect.
type PostInitCustomizer interface {
	Customizer
	PostInit(ctx context.Context, r *Registrar) error
}

// EngineOptions is a functional option that customizes an *Engine.
type EngineOptions func(*Engine)
/*********************************
Request
*********************************/
// DecodeRequestFunc extracts a payload from an http.Request. It's designed to be used by MvcMapping.
// Common implementations include a JSON decoder or a form-data extractor.
type DecodeRequestFunc func(ctx context.Context, httpReq *http.Request) (req interface{}, err error)

// RequestRewriter handles request rewrites, e.g. rewriting http.Request.URL.Path.
type RequestRewriter interface {
	// HandleRewrite takes the rewritten request and puts it through the entire handling cycle.
	// The http.Request.Context() is carried over.
	// Note: if no error is returned, the caller should stop processing the original request and discard it.
	HandleRewrite(rewritten *http.Request) error
}
/*********************************
Response
*********************************/
// EncodeResponseFunc encodes a user response object into http.ResponseWriter. It's designed to be used by MvcMapping.
// Common implementations include a JSON encoder or a template-based HTML generator.
type EncodeResponseFunc func(ctx context.Context, rw http.ResponseWriter, resp interface{}) error

// StatusCoder is an additional interface that a user response object or error could implement.
// EncodeResponseFunc and EncodeErrorFunc should typically check for this interface and set the response status code accordingly.
type StatusCoder interface {
	StatusCode() int
}

// Headerer is an additional interface that a user response object or error could implement.
// EncodeResponseFunc and EncodeErrorFunc should typically check for this interface and set response headers accordingly.
type Headerer interface {
	Headers() http.Header
}

// BodyContainer is an additional interface that a user response object or error could implement.
// This interface is mainly used internally for mapping.
type BodyContainer interface {
	Body() interface{}
}
/*********************************
Error Translator
*********************************/
// EncodeErrorFunc is responsible for encoding an error to the ResponseWriter. It's designed to be used by MvcMapping.
// Common implementations include a JSON encoder or a template-based HTML generator.
type EncodeErrorFunc func(ctx context.Context, err error, w http.ResponseWriter)

// ErrorTranslator can be registered via web.Registrar.
// It contributes to MvcMapping's error handling process.
// Note: it does not contribute to Middleware's error handling.
//
// Implementing Notes:
// 1. if it doesn't handle the error, return the same error
// 2. if a custom StatusCode is required, make the returned error implement StatusCoder
// 3. if a custom Header is required, make the returned error implement Headerer
// 4. HttpError helps with custom Headerer and StatusCoder implementations
type ErrorTranslator interface {
	Translate(ctx context.Context, err error) error
}

// ErrorTranslateFunc is ErrorTranslator in function form. It is mostly used for selective error translation
// registration (ErrorTranslateMapping). The same implementing rules apply.
type ErrorTranslateFunc func(ctx context.Context, err error) error

// Translate implements ErrorTranslator by calling fn.
func (fn ErrorTranslateFunc) Translate(ctx context.Context, err error) error {
	return fn(ctx, err)
}
/*********************************
Mappings
*********************************/
// Controller is usually implemented by user-domain types to provide a group of HTTP handlers.
// Each Controller provides a list of Mapping that defines how HTTP requests should be handled.
// See Mapping
type Controller interface {
	Mappings() []Mapping
}

// Mapping is the generic interface for all kinds of HTTP mappings.
// User-domain code typically does not implement this interface. Instead, the predefined implementations
// and their builders should be used.
// See StaticMapping, RoutedMapping, MvcMapping, SimpleMapping, etc.
type Mapping interface {
	Name() string
}

// StaticMapping defines static assets handling. e.g. javascripts, css, images, etc.
// See assets.New()
type StaticMapping interface {
	Mapping
	Path() string
	StaticRoot() string
	Aliases() map[string]string
	AddAlias(path, filePath string) StaticMapping
}

// RoutedMapping defines dynamic HTTP handling with a specific HTTP Route (path and method) and, optionally,
// a RequestMatcher as condition.
// RoutedMapping includes SimpleMapping, MvcMapping, etc.
type RoutedMapping interface {
	Mapping
	Group() string
	Path() string
	Method() string
	Condition() RequestMatcher
}

// SimpleMapping is an endpoint implemented directly as an http.HandlerFunc.
// See mapping.MappingBuilder
type SimpleMapping interface {
	RoutedMapping
	HandlerFunc() http.HandlerFunc
}

// MvcHandlerFunc is the generic function to be used for MvcMapping.
// See MvcMapping, rest.EndpointFunc, template.ModelViewHandlerFunc
type MvcHandlerFunc func(c context.Context, request interface{}) (response interface{}, err error)

// MvcMapping defines HTTP handling that follows the MVC pattern:
//  1. The http.Request is decoded into a request model object using MvcMapping.DecodeRequestFunc().
//  2. The request model object is processed by MvcMapping.HandlerFunc() and a response model object is returned.
//  3. The response model object is rendered into http.ResponseWriter using MvcMapping.EncodeResponseFunc().
//  4. If any step yields an error, the error is rendered into http.ResponseWriter using MvcMapping.EncodeErrorFunc().
//
// Note:
// These functions all have weakly typed signatures. User-domain developers typically should use mapping builders
// (rest.MappingBuilder, template.MappingBuilder, etc) to create concrete MvcMapping instances.
// See EndpointMapping or TemplateMapping
type MvcMapping interface {
	RoutedMapping
	DecodeRequestFunc() DecodeRequestFunc
	EncodeResponseFunc() EncodeResponseFunc
	EncodeErrorFunc() EncodeErrorFunc
	HandlerFunc() MvcHandlerFunc
}

// EndpointMapping defines REST API mapping.
// A REST API is usually implemented by a Controller and accepts/produces JSON objects.
// See rest.MappingBuilder
type EndpointMapping MvcMapping

// TemplateMapping defines templated MVC mapping. e.g. html templates
// Templated MVC is usually implemented by a Controller and produces a template and model for dynamic html generation.
// See template.MappingBuilder
type TemplateMapping MvcMapping

// MiddlewareMapping defines middlewares that apply to all requests, or to a subset selected via Matcher and Condition.
// Middlewares are often used for tasks like security, pre/post processing of requests or responses, metrics, etc.
// See middleware.MappingBuilder
type MiddlewareMapping interface {
	Mapping
	Matcher() RouteMatcher
	Order() int
	Condition() RequestMatcher
	HandlerFunc() http.HandlerFunc
}

// ErrorTranslateMapping defines how errors should be translated before they are rendered into http.ResponseWriter.
// See weberror.MappingBuilder
type ErrorTranslateMapping interface {
	Mapping
	Matcher() RouteMatcher
	Order() int
	Condition() RequestMatcher
	TranslateFunc() ErrorTranslateFunc
}
/*********************************
Routing Matchers
*********************************/
// Route contains the information needed for registering a handler func in gin.Engine.
type Route struct {
	Method string
	Path   string
	Group  string
}

// RouteMatcher is a typed ChainableMatcher that accepts *Route or Route.
type RouteMatcher interface {
	matcher.ChainableMatcher
}

// RequestMatcher is a typed ChainableMatcher that accepts *http.Request or http.Request.
type RequestMatcher interface {
	matcher.ChainableMatcher
}
// NormalizedPath replaces every path-parameter name in path with the placeholder ":var", so paths that
// differ only in parameter names compare equal. For example, "/path/with/:param" and
// "/path/with/:other_param_name" both normalize to "/path/with/:var".
func NormalizedPath(path string) string {
	const placeholder = "/:var"
	normalized := pathParamPattern.ReplaceAllString(path, placeholder)
	return normalized
}
/*********************************
SimpleMapping
*********************************/
// simpleMapping implements SimpleMapping as an immutable set of route attributes plus a handler.
type simpleMapping struct {
	name        string
	group       string
	path        string
	method      string
	condition   RequestMatcher
	handlerFunc http.HandlerFunc
}

// NewSimpleMapping create a SimpleMapping.
// It's recommended to use mapping.MappingBuilder instead of this function:
// e.g.
// <code>
// mapping.Post("/path/to/api").HandlerFunc(func...).Build()
// </code>
func NewSimpleMapping(name, group, path, method string, condition RequestMatcher, handlerFunc http.HandlerFunc) SimpleMapping {
	m := simpleMapping{
		name:        name,
		group:       group,
		path:        path,
		method:      method,
		condition:   condition,
		handlerFunc: handlerFunc,
	}
	return &m
}

// Name returns the mapping name.
func (m simpleMapping) Name() string {
	return m.name
}

// Group returns the route group.
func (m simpleMapping) Group() string {
	return m.group
}

// Path returns the route path.
func (m simpleMapping) Path() string {
	return m.path
}

// Method returns the HTTP method.
func (m simpleMapping) Method() string {
	return m.method
}

// Condition returns the optional request condition (may be nil).
func (m simpleMapping) Condition() RequestMatcher {
	return m.condition
}

// HandlerFunc returns the handler serving matched requests.
func (m simpleMapping) HandlerFunc() http.HandlerFunc {
	return m.handlerFunc
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package cors
import (
"context"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/rs/cors"
"time"
)
// Customizer implements web.Customizer by installing a global CORS middleware built from CorsProperties.
type Customizer struct {
	properties CorsProperties
}

// newCustomizer returns a web.Customizer that applies the given CORS properties.
func newCustomizer(properties CorsProperties) web.Customizer {
	return &Customizer{
		properties: properties,
	}
}

// Customize does nothing when CORS is disabled. Otherwise it registers a CORS middleware as a global
// middleware on r. MaxAge is converted to whole seconds, and OPTIONS preflight requests are not passed through.
func (c *Customizer) Customize(_ context.Context, r *web.Registrar) (err error) {
	if !c.properties.Enabled {
		return nil
	}
	props := c.properties
	maxAgeSeconds := int(time.Duration(props.MaxAge).Seconds())
	mw := New(cors.Options{
		AllowedOrigins:     props.AllowedOrigins(),
		AllowedMethods:     props.AllowedMethods(),
		AllowedHeaders:     props.AllowedHeaders(),
		ExposedHeaders:     props.ExposedHeaders(),
		MaxAge:             maxAgeSeconds,
		AllowCredentials:   props.AllowCredentials,
		OptionsPassthrough: false,
		//Debug: true,
	})
	return r.AddGlobalMiddlewares(mw)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package cors
import (
"github.com/gin-gonic/gin"
"github.com/rs/cors"
"net/http"
)
// Options is a configuration container to setup the CORS middleware.
type Options = cors.Options

// corsWrapper adapts a *cors.Cors handler to a Gin middleware.
type corsWrapper struct {
	*cors.Cors
	optionPassthrough bool
}

// build transforms the wrapped cors.Cors handler into a Gin middleware. Every request goes through the
// CORS handler. Unless OPTIONS passthrough is enabled, a preflight request (an OPTIONS request with an
// Access-Control-Request-Method header) is then aborted with 200 so later Gin handlers don't run.
func (c corsWrapper) build() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		c.HandlerFunc(ctx.Writer, ctx.Request)
		isPreflight := ctx.Request.Method == http.MethodOptions &&
			ctx.GetHeader("Access-Control-Request-Method") != ""
		if isPreflight && !c.optionPassthrough {
			// Abort processing next Gin middlewares.
			ctx.AbortWithStatus(http.StatusOK)
		}
	}
}

// New creates a new CORS Gin middleware with the provided options.
func New(options Options) gin.HandlerFunc {
	wrapper := corsWrapper{
		Cors:              cors.New(options),
		optionPassthrough: options.OptionsPassthrough,
	}
	return wrapper.build()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package cors
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"net/http"
"strings"
"time"
)
const (
	// CorsPropertiesPrefix is the configuration prefix CorsProperties is bound from.
	CorsPropertiesPrefix = "security.cors"
	// listSeparator separates entries in the comma-separated *Str properties.
	listSeparator = ","
)

var (
	// allMethods is the method list substituted when the allowed methods contain "*".
	allMethods = []string{
		http.MethodGet, http.MethodHead, http.MethodPost,
		http.MethodPut, http.MethodPatch, http.MethodDelete,
		//http.MethodConnect, http.MethodOptions, http.MethodTrace,
	}
)
// CorsProperties holds CORS configuration bound from CorsPropertiesPrefix.
// List-valued settings are comma-separated strings; see the accessor methods for the parsed forms.
type CorsProperties struct {
	// Whether the CORS middleware is installed at all.
	Enabled bool `json:"enabled"`
	// Comma-separated list of origins to allow. '*' allows all origins. Default to '*'
	AllowedOriginsStr string `json:"allowed-origins"`
	// Comma-separated list of methods to allow. '*' allows all methods. Default to '*'
	AllowedMethodsStr string `json:"allowed-methods"`
	// Comma-separated list of headers to allow in a request. '*' allows all headers. Default to '*'
	AllowedHeadersStr string `json:"allowed-headers"`
	// Comma-separated list of headers to include in a response.
	ExposedHeadersStr string `json:"exposed-headers"`
	// Whether credentials are supported. When not set, credentials are not supported.
	AllowCredentials bool `json:"allow-credentials"`
	// How long the response from a pre-flight request can be cached by clients.
	// If a duration suffix is not specified, seconds will be used.
	MaxAge utils.Duration `json:"max-age"`
}
// NewCorsProperties creates a CorsProperties populated with defaults. CORS is disabled; all origins,
// methods and headers are allowed; no headers are exposed; credentials are not allowed; MaxAge is 30 minutes.
func NewCorsProperties() *CorsProperties {
	props := &CorsProperties{
		AllowedOriginsStr: "*",
		AllowedMethodsStr: "*",
		AllowedHeadersStr: "*",
		MaxAge:            utils.Duration(1800 * time.Second),
	}
	// spelled out for clarity even though they are the zero values
	props.Enabled = false
	props.ExposedHeadersStr = ""
	props.AllowCredentials = false
	return props
}
// AllowedOrigins returns the trimmed, comma-separated entries of AllowedOriginsStr.
// NOTE(review): empty entries are intentionally kept. Dropping them would turn an explicit "" into an empty
// list, which rs/cors treats as "allow all origins" — confirm before changing.
func (p CorsProperties) AllowedOrigins() []string {
	return splitAndTrim(p.AllowedOriginsStr)
}

// AllowedMethods returns the trimmed entries of AllowedMethodsStr. If any entry is "*", it returns a copy of
// allMethods, so callers cannot mutate the shared package-level default.
func (p CorsProperties) AllowedMethods() []string {
	list := splitAndTrim(p.AllowedMethodsStr)
	for _, v := range list {
		if v == "*" {
			return append([]string(nil), allMethods...)
		}
	}
	return list
}

// AllowedHeaders returns the trimmed, comma-separated entries of AllowedHeadersStr.
func (p CorsProperties) AllowedHeaders() []string {
	return splitAndTrim(p.AllowedHeadersStr)
}

// ExposedHeaders returns the trimmed, non-empty entries of ExposedHeadersStr, or nil when there are none.
// Blank entries are dropped so that the default "" (or stray commas) do not yield empty header names.
func (p CorsProperties) ExposedHeaders() []string {
	var headers []string
	for _, h := range splitAndTrim(p.ExposedHeadersStr) {
		if h != "" {
			headers = append(headers, h)
		}
	}
	return headers
}
// BindCorsProperties creates a CorsProperties with defaults and binds the configuration under
// CorsPropertiesPrefix onto it. It panics if binding fails.
func BindCorsProperties(ctx *bootstrap.ApplicationContext) CorsProperties {
	props := NewCorsProperties()
	err := ctx.Config().Bind(props, CorsPropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind CorsProperties"))
	}
	return *props
}
// splitAndTrim splits s on listSeparator and trims surrounding whitespace from every part.
// Empty parts are kept, so "" yields a single empty string.
func splitAndTrim(s string) []string {
	parts := strings.Split(s, listSeparator)
	trimmed := make([]string, len(parts))
	for i := range parts {
		trimmed[i] = strings.TrimSpace(parts[i])
	}
	return trimmed
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/gin-gonic/gin"
"net/http"
)
// RequestPreProcessorName identifies a RequestPreProcessor.
type RequestPreProcessorName string

// RequestPreProcessor runs against every request before it reaches gin (see Engine.ServeHTTP).
// Returning an error aborts the request.
type RequestPreProcessor interface {
	Process(r *http.Request) error
	Name() RequestPreProcessorName
}
// Engine wraps *gin.Engine and runs a chain of RequestPreProcessor before a request reaches gin.
type Engine struct {
	*gin.Engine
	requestPreProcessor []RequestPreProcessor
}

// ServeHTTP implements http.Handler. Pre-processors run in registration order. The first failure answers
// with 500 and a plain-text message; otherwise the request is served by the embedded gin.Engine.
// NOTE(review): the pre-processor error itself is discarded here — consider logging it.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	for i := range e.requestPreProcessor {
		if err := e.requestPreProcessor[i].Process(r); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = fmt.Fprint(w, "Internal error with request cache")
			return
		}
	}
	e.Engine.ServeHTTP(w, r)
}

// addRequestPreProcessor appends p to the end of the pre-processor chain.
func (e *Engine) addRequestPreProcessor(p RequestPreProcessor) {
	e.requestPreProcessor = append(e.requestPreProcessor, p)
}
// NewEngine creates an Engine around gin.New(). As a side effect it sets gin's global mode: debug when
// bootstrap.DebugEnabled() is true, release otherwise.
func NewEngine() *Engine {
	mode := gin.ReleaseMode
	if bootstrap.DebugEnabled() {
		mode = gin.DebugMode
	}
	gin.SetMode(mode)
	return &Engine{Engine: gin.New()}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"encoding/json"
"fmt"
"github.com/go-playground/validator/v10"
"net/http"
)
const (
	// templateValidationFieldError formats one validation failure: field name, then the failed validator tag.
	templateValidationFieldError = "validation failed on '%s' with criteria '%s'"
)
/**************************
Generic Http Error
***************************/
// HttpErrorResponse is the default JSON body rendered for HTTP errors. Empty fields are omitted.
type HttpErrorResponse struct {
	StatusText string            `json:"error,omitempty"`
	Message    string            `json:"message,omitempty"`
	Details    map[string]string `json:"details,omitempty"`
}

// HttpError implements error, json.Marshaler, StatusCoder and Headerer around a wrapped error,
// optionally carrying a status code (SC) and response headers (H). When the wrapped error itself
// implements json.Marshaler, StatusCoder or Headerer, its implementation takes precedence.
// Note: Do not use HttpError as a map key, because it is not hashable (it contains http.Header which is a map).
type HttpError struct {
	error
	SC int
	H  http.Header
}

// Error implements error by delegating to the wrapped error. When no error is wrapped it returns the
// standard status text instead of panicking on the nil embedded interface.
func (e HttpError) Error() string {
	if e.error == nil {
		return http.StatusText(e.StatusCode())
	}
	return e.error.Error()
}

// Unwrap returns the wrapped error so errors.Is / errors.As can inspect it,
// consistent with BadRequestError and BindingError.
func (e HttpError) Unwrap() error {
	return e.error
}

// MarshalJSON implements json.Marshaler. It defers to the wrapped error when that is a json.Marshaler;
// otherwise it renders an HttpErrorResponse with the status text and error message.
func (e HttpError) MarshalJSON() ([]byte, error) {
	//nolint:errorlint // only the directly wrapped error is consulted, by design
	if original, ok := e.error.(json.Marshaler); ok {
		return original.MarshalJSON()
	}
	return json.Marshal(&HttpErrorResponse{
		StatusText: http.StatusText(e.StatusCode()),
		Message:    e.Error(),
	})
}

// StatusCode implements StatusCoder. Precedence: the wrapped error's StatusCode, then SC,
// then 500 when SC is unset.
func (e HttpError) StatusCode() int {
	//nolint:errorlint // only the directly wrapped error is consulted, by design
	if original, ok := e.error.(StatusCoder); ok {
		return original.StatusCode()
	}
	if e.SC == 0 {
		return http.StatusInternalServerError
	}
	return e.SC
}

// Headers implements Headerer. The wrapped error's headers take precedence over H.
func (e HttpError) Headers() http.Header {
	//nolint:errorlint // only the directly wrapped error is consulted, by design
	if original, ok := e.error.(Headerer); ok {
		return original.Headers()
	}
	return e.H
}
/**************************
BadRequest Errors
***************************/
// BadRequestError wraps an error and reports 400 Bad Request.
type BadRequestError struct {
	error
}

// StatusCode implements StatusCoder and always returns 400.
func (BadRequestError) StatusCode() int {
	return http.StatusBadRequest
}

// Unwrap returns the wrapped error.
func (b BadRequestError) Unwrap() error {
	return b.error
}

// BindingError wraps a request binding failure and reports 400 Bad Request.
type BindingError struct {
	error
}

// StatusCode implements StatusCoder and always returns 400.
func (BindingError) StatusCode() int {
	return http.StatusBadRequest
}

// Unwrap returns the wrapped error.
func (b BindingError) Unwrap() error {
	return b.error
}
// ValidationErrors wraps validator.ValidationErrors so it renders as a 400 JSON response with
// per-field details.
type ValidationErrors struct {
	validator.ValidationErrors
}

// Unwrap exposes the underlying validator.ValidationErrors.
func (e ValidationErrors) Unwrap() error {
	return e.ValidationErrors
}

// MarshalJSON implements json.Marshaler. Each field error becomes an entry in "details", keyed by the
// field's namespace.
func (e ValidationErrors) MarshalJSON() ([]byte, error) {
	details := make(map[string]string, len(e.ValidationErrors))
	for _, fieldErr := range e.ValidationErrors {
		details[fieldErr.Namespace()] = fmt.Sprintf(templateValidationFieldError, fieldErr.Field(), fieldErr.Tag())
	}
	return json.Marshal(&HttpErrorResponse{
		StatusText: http.StatusText(e.StatusCode()),
		Message:    "validation failed",
		Details:    details,
	})
}

// StatusCode implements StatusCoder and always returns 400.
func (ValidationErrors) StatusCode() int {
	return http.StatusBadRequest
}
/*****************************
Constructor Functions
******************************/
func NewHttpError(sc int, err error, headers ...http.Header) error {
var h http.Header
if len(headers) != 0 {
h = make(http.Header)
for _,toMerge := range headers {
mergeHeaders(h, toMerge)
}
}
return HttpError{error: err, SC: sc, H: h}
}
func NewBadRequestError(err error) error {
return BadRequestError{error: err}
}
func NewBindingError(e error) error {
return BindingError{error: e}
}
// mergeHeaders appends every value in toMerge to dst via http.Header.Add, so existing values in dst are kept.
func mergeHeaders(dst http.Header, toMerge http.Header) {
	for key, vals := range toMerge {
		for i := range vals {
			dst.Add(key, vals[i])
		}
	}
}
/*****************************
Privates
******************************/
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"context"
"encoding/json"
"fmt"
"github.com/go-playground/validator/v10"
"net/http"
)
/*************************
ErrorHandlerMapping
*************************/
// errorTranslateMapping implements ErrorTranslateMapping as a plain value holder.
type errorTranslateMapping struct {
	name          string
	order         int
	matcher       RouteMatcher
	condition     RequestMatcher
	translateFunc ErrorTranslateFunc
}

// NewErrorTranslateMapping creates an ErrorTranslateMapping from its parts.
func NewErrorTranslateMapping(name string, order int, matcher RouteMatcher, cond RequestMatcher, translateFunc ErrorTranslateFunc) ErrorTranslateMapping {
	etm := new(errorTranslateMapping)
	etm.name = name
	etm.order = order
	etm.matcher = matcher
	etm.condition = cond
	etm.translateFunc = translateFunc
	return etm
}

// Name returns the mapping name.
func (etm errorTranslateMapping) Name() string {
	return etm.name
}

// Matcher returns the route matcher selecting where this translator applies.
func (etm errorTranslateMapping) Matcher() RouteMatcher {
	return etm.matcher
}

// Order returns the ordering value of this mapping.
func (etm errorTranslateMapping) Order() int {
	return etm.order
}

// Condition returns the optional request condition.
func (etm errorTranslateMapping) Condition() RequestMatcher {
	return etm.condition
}

// TranslateFunc returns the translation function.
func (etm errorTranslateMapping) TranslateFunc() ErrorTranslateFunc {
	return etm.translateFunc
}
/*************************
Error Translation
*************************/
// newErrorEncoder returns an EncodeErrorFunc that passes the error through every translator in order,
// each receiving the previous one's output, and then hands the final error to encoder.
func newErrorEncoder(encoder EncodeErrorFunc, translators ...ErrorTranslator) EncodeErrorFunc {
	translate := func(ctx context.Context, err error) error {
		for i := range translators {
			err = translators[i].Translate(ctx, err)
		}
		return err
	}
	return func(ctx context.Context, err error, rw http.ResponseWriter) {
		encoder(ctx, translate(ctx, err), rw)
	}
}
// mappedErrorTranslator adapts an ErrorTranslateMapping into an ordered ErrorTranslator whose translation
// is skipped when its request condition does not match.
type mappedErrorTranslator struct {
	order         int
	condition     RequestMatcher
	translateFunc ErrorTranslateFunc
}

// Order returns the ordering value copied from the mapping.
func (t mappedErrorTranslator) Order() int {
	return t.order
}

// Translate applies translateFunc unless the condition can be evaluated and either fails to match or
// returns an error; in that case err is returned untouched.
// NOTE(review): when a condition exists but ctx carries no gin context, the translation is still applied.
func (t mappedErrorTranslator) Translate(ctx context.Context, err error) error {
	if t.condition == nil {
		return t.translateFunc(ctx, err)
	}
	ginCtx := GinContext(ctx)
	if ginCtx == nil {
		return t.translateFunc(ctx, err)
	}
	matched, e := t.condition.MatchesWithContext(ctx, ginCtx.Request)
	if e != nil || !matched {
		return err
	}
	return t.translateFunc(ctx, err)
}

// newMappedErrorTranslator copies order, condition and translate function out of m.
func newMappedErrorTranslator(m ErrorTranslateMapping) *mappedErrorTranslator {
	return &mappedErrorTranslator{
		order:         m.Order(),
		condition:     m.Condition(),
		translateFunc: m.TranslateFunc(),
	}
}
// defaultErrorTranslator is the fallback ErrorTranslator:
//   - validator.ValidationErrors are wrapped into ValidationErrors
//   - errors that are StatusCoder or HttpError are returned as-is
//   - anything else becomes an HttpError with status 500
type defaultErrorTranslator struct{}

// Translate converts err according to the rules documented on defaultErrorTranslator.
// Only the outermost error is inspected (type assertions, not errors.As), which is intentional.
//
//nolint:errorlint
func (i defaultErrorTranslator) Translate(_ context.Context, err error) error {
	if validationErrs, ok := err.(validator.ValidationErrors); ok {
		return ValidationErrors{validationErrs}
	}
	if _, ok := err.(StatusCoder); ok {
		return err
	}
	if _, ok := err.(HttpError); ok {
		return err
	}
	return HttpError{error: err, SC: http.StatusInternalServerError}
}

// newDefaultErrorTranslator returns a ready-to-use defaultErrorTranslator (the zero value is sufficient).
func newDefaultErrorTranslator() defaultErrorTranslator {
	var t defaultErrorTranslator
	return t
}
/*****************************
Error Encoder
******************************/
// JsonErrorEncoder returns an EncodeErrorFunc that writes errors as JSON responses.
func JsonErrorEncoder() EncodeErrorFunc {
	return jsonErrorEncoder
}

// jsonErrorEncoder writes err as a JSON HTTP response:
//   - body: err marshalled as JSON. If err doesn't implement json.Marshaler, it is first wrapped with
//     NewHttpError(0, err). If marshalling fails, a minimal {"error":"<message>"} body is written instead.
//   - headers: Content-Type is set to "application/json; charset=utf-8", and any headers exposed via
//     Headerer are added.
//   - status: StatusCoder.StatusCode() if implemented, otherwise 500.
//
// Only the outermost error is inspected (type assertions, not errors.As), which is intentional.
//
//nolint:errorlint
func jsonErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
	// body
	if _, ok := err.(json.Marshaler); !ok {
		err = NewHttpError(0, err)
	}
	body, e := json.Marshal(err)
	if e != nil {
		// JSON-encode the message so the fallback stays valid JSON even when the message contains quotes,
		// backslashes or control characters. Marshalling a string does not fail.
		msg, _ := json.Marshal(err.Error())
		body = []byte(fmt.Sprintf(`{"error":%s}`, msg))
	}

	// headers
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	if headerer, ok := err.(Headerer); ok {
		for k, values := range headerer.Headers() {
			for _, v := range values {
				w.Header().Add(k, v)
			}
		}
	}

	// status code
	code := http.StatusInternalServerError
	if sc, ok := err.(StatusCoder); ok {
		code = sc.StatusCode()
	}

	// write response
	w.WriteHeader(code)
	_, _ = w.Write(body)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"errors"
"github.com/bmatcuk/doublestar/v4"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"io/fs"
"os"
"path"
"path/filepath"
"strings"
)
const (
DirFSAllowListDirectory DirFSOption = 1 << iota
//...
)
type DirFSOption int64
// dirFS implements fs.FS and fs.GlobFS. It is similar to http.Dir, but support fs.FS and allow option to not list directory contents
type dirFS struct {
dir string
fs fs.FS
opts DirFSOption
}
func NewOSDirFS(dir string, opts ...DirFSOption) fs.FS {
return NewDirFS("", os.DirFS(dir), opts...)
}
func NewDirFS(dir string, fsys fs.FS, opts ...DirFSOption) fs.FS {
options := DirFSOption(0)
for _, opt := range opts {
options = options | opt
}
return &dirFS{
dir: dir,
fs: fsys,
opts: options,
}
}
// Open implements fs.FS.
// The name is cleaned (as an absolute path, so ".." cannot escape the root) and joined with f.dir
// ("." when empty) before being opened from the underlying fs.FS.
// Unless DirFSAllowListDirectory is set, directories are reported as fs.ErrNotExist.
// Any file opened but not returned is closed.
func (f *dirFS) Open(name string) (fs.File, error) {
	// same guard as http.Dir: a raw OS separator that is not '/' is rejected
	if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) {
		return nil, errors.New("invalid character in file path")
	}
	dir := f.dir
	if dir == "" {
		dir = "."
	}
	fullName := filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name)))
	file, e := f.fs.Open(fullName)
	if e != nil {
		return nil, f.translateError(e, fullName)
	}

	// apply options
	if f.hasOption(DirFSAllowListDirectory) {
		return file, nil
	}
	stat, e := file.Stat()
	switch {
	case e != nil:
		_ = file.Close()
		return nil, f.translateError(e, fullName)
	case stat.IsDir():
		// hide directories so their content cannot be listed
		_ = file.Close()
		return nil, fs.ErrNotExist
	}
	return file, nil
}
// Glob implements fs.GlobFS by delegating to doublestar.Glob on the underlying fs.FS.
// NOTE(review): the pattern is applied to the underlying fs.FS directly, without prefixing f.dir — confirm intended.
func (f *dirFS) Glob(pattern string) (ret []string, err error) {
	ret, err = doublestar.Glob(f.fs, pattern)
	return ret, err
}

// hasOption reports whether the given option bit(s) are set on this dirFS.
func (f *dirFS) hasOption(opt DirFSOption) bool {
	masked := f.opts & opt
	return masked != 0
}
// translateError maps the non-nil error returned when opening name to a possibly better non-nil error.
// In particular, it turns errors about opening paths that traverse a non-directory into fs.ErrNotExist.
// Modeled after net/http's mapOpenError:
//   - errors wrapping fs.ErrNotExist or fs.ErrPermission are returned as-is (matched with errors.Is, because
//     fs.FS implementations usually return *fs.PathError)
//   - otherwise each path prefix is stat-ed on the same fs.FS used by Open (not the process working directory);
//     if a prefix is a regular file, fs.ErrNotExist is returned
//   - in every other case the original error is returned
func (f *dirFS) translateError(err error, name string) error {
	if errors.Is(err, fs.ErrNotExist) || errors.Is(err, fs.ErrPermission) {
		return err
	}
	// name is built with filepath.Join (OS separators), but fs.FS paths are always slash-separated
	parts := strings.Split(filepath.ToSlash(name), "/")
	for i := range parts {
		if parts[i] == "" {
			continue
		}
		fi, e := fs.Stat(f.fs, strings.Join(parts[:i+1], "/"))
		if e != nil {
			return err
		}
		if !fi.IsDir() {
			return fs.ErrNotExist
		}
	}
	return err
}
// orderedFS implements fs.FS and order.Ordered
type orderedFS struct {
fs.FS
order int
}
// OrderedFS returns a fs.FS that also implements order.Ordered
// if the given fs.FS is already implement the order.Ordered, "defaultOrder" is ignored
func OrderedFS(fsys fs.FS, defaultOrder int) fs.FS {
return &orderedFS{
FS: fsys,
order: defaultOrder,
}
}
func (f orderedFS) Order() int {
return f.order
}
// MergedFS implements fs.FS and fs.GlobFS on top of multiple source fs.FS.
// Sources are kept sorted with order.OrderedFirstCompare.
type MergedFS struct {
	srcFS []fs.FS
}

// NewMergedFS creates a MergedFS from at least one fs.FS.
// "atLeastOne" is placed after the other sources before a stable sort, so it comes last among sources
// that compare equal. The caller's variadic slice is never modified.
func NewMergedFS(atLeastOne fs.FS, fs ...fs.FS) *MergedFS {
	// The full slice expression caps capacity at len(fs), so append always allocates a fresh backing array.
	// Without it, a caller passing a slice with spare capacity (NewMergedFS(a, s...)) would have its array
	// written by append and reordered by the in-place sort below.
	src := append(fs[:len(fs):len(fs)], atLeastOne)
	order.SortStable(src, order.OrderedFirstCompare)
	return &MergedFS{
		srcFS: src,
	}
}
// Open implements fs.FS. Sources are tried in their sorted order and the first one that opens name wins.
// Errors from individual sources are discarded. If no source can open name, a *fs.PathError wrapping
// fs.ErrNotExist is returned.
func (m *MergedFS) Open(name string) (fs.File, error) {
	for i := range m.srcFS {
		file, e := m.srcFS[i].Open(name)
		if e != nil {
			continue
		}
		return file, nil
	}
	return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
}

// Glob implements fs.GlobFS. It runs doublestar.Glob on every source, from the last to the first, and
// concatenates the results. Duplicates are not removed, and the first failing source aborts the whole call.
func (m *MergedFS) Glob(pattern string) (ret []string, err error) {
	for i := len(m.srcFS); i > 0; i-- {
		matches, e := doublestar.Glob(m.srcFS[i-1], pattern)
		if e != nil {
			return nil, e
		}
		ret = append(ret, matches...)
	}
	return ret, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"fmt"
"go.uber.org/fx"
"reflect"
"runtime"
)
// Cached reflect.Type values used when validating and annotating fx providers.
// Interface types are obtained through the parameter type of a function literal (.In(0)), since
// reflect.TypeOf on an interface value would only yield its dynamic type.
var (
// interface types checked against provider return values by groupedProviders / validateFxProviderTarget
typeController = reflect.TypeOf(func(Controller) { /* empty */ }).In(0)
typeCustomizer = reflect.TypeOf(func(Customizer) { /* empty */ }).In(0)
typeErrorTranslator = reflect.TypeOf(func(ErrorTranslator) { /* empty */ }).In(0)
// the built-in error interface, allowed as the last provider return value
typeError = reflect.TypeOf((*error)(nil)).Elem()
typeFxOut = reflect.TypeOf(fx.Out{})
// fx.In: providers whose struct parameters contain it cannot be re-annotated
typeFxIn = reflect.TypeOf(fx.In{})
)
// FxControllerProviders registers each target as an fx provider whose results join the FxGroupControllers
// group as Controller. It panics if a target's signature is not acceptable (see validateFxProviderTarget).
func FxControllerProviders(targets ...interface{}) fx.Option {
	return fx.Provide(groupedProviders(FxGroupControllers, typeController, targets)...)
}

// FxCustomizerProviders registers each target as an fx provider whose results join the FxGroupCustomizers
// group as Customizer. It panics if a target's signature is not acceptable (see validateFxProviderTarget).
func FxCustomizerProviders(targets ...interface{}) fx.Option {
	return fx.Provide(groupedProviders(FxGroupCustomizers, typeCustomizer, targets)...)
}

// FxErrorTranslatorProviders registers each target as an fx provider whose results join the
// FxGroupErrorTranslator group as ErrorTranslator. It panics if a target's signature is not acceptable
// (see validateFxProviderTarget).
func FxErrorTranslatorProviders(targets ...interface{}) fx.Option {
	return fx.Provide(groupedProviders(FxGroupErrorTranslator, typeErrorTranslator, targets)...)
}
// groupedProviders converts each target into an fx provider that contributes to the given value "group".
// Each target is validated against interfaceType (panics on validation error):
//   - if its results are exactly interfaceType, it is wrapped as fx.Annotated{Group: group}
//   - otherwise it is wrapped with fx.Annotate + fx.As(interfaceType) and a group result tag per result
func groupedProviders(group string, interfaceType reflect.Type, targets []interface{}) []interface{} {
	providers := make([]interface{}, len(targets))
	groupTag := fmt.Sprintf("group:\"%s\"", group)
	for idx, target := range targets {
		annotate, outCount, err := validateFxProviderTarget(interfaceType, target)
		if err != nil {
			panic(err)
		}
		if !annotate {
			providers[idx] = fx.Annotated{
				Group:  group,
				Target: target,
			}
			continue
		}
		asTypes := make([]interface{}, outCount)
		resultTags := make([]string, outCount)
		for j := range asTypes {
			// fx.As expects pointers to interfaces, e.g. fx.As(new(io.Writer)).
			// reflect.New(interfaceType) yields a *<interfaceType> (e.g. *Controller), and Interface()
			// turns that reflect.Value back into an interface{} usable by fx.As.
			asTypes[j] = reflect.New(interfaceType).Interface()
			resultTags[j] = groupTag
		}
		providers[idx] = fx.Annotate(target, fx.As(asTypes...), fx.ResultTags(resultTags...))
	}
	return providers
}
// validateFxProviderTarget performs a best-effort validation of an fx provider "target" against interfaceType.
//
// Rules:
//   - target must be a function; otherwise this function panics (no error is returned for this case).
//   - every return value must implement interfaceType, except that the last one (when it is not the only one)
//     may instead be exactly the built-in error type.
//   - if any return value implements interfaceType without being exactly interfaceType, the provider must be
//     re-annotated (fx.Annotate/fx.As). That is impossible when a struct input parameter has a field of type fx.In.
//
// Returns:
//   - shouldAnnotate: true when the caller should wrap target with fx.Annotate instead of fx.Annotated
//   - effectiveNumOut: number of return values implementing interfaceType (0 when err != nil)
//   - err: non-nil when the output or input rules above are violated
func validateFxProviderTarget(interfaceType reflect.Type, target interface{}) (shouldAnnotate bool, effectiveNumOut int, err error) {
t := reflect.TypeOf(target)
if t.Kind() != reflect.Func {
panic(fmt.Errorf("fx annotated provider target must be a function, but got %T", target))
}
// 1. the return types must implement interfaceType, except the last return value
// 1.a if a return type is not exactly interfaceType, the provider must be suitable for annotation (i.e. it can't use fx.In)
// 2. the last return value can be error
isOutputValid := true
for i := 0; i < t.NumOut(); i++ {
rt := t.Out(i)
if !rt.Implements(interfaceType) {
// if it's the last return value (and not the only one), it may be the error type
if i > 0 && i == t.NumOut()-1 {
if !isExactType(typeError, rt) {
isOutputValid = false
break
}
} else { // every return item other than the last one must implement the expected interface
isOutputValid = false
break
}
} else {
// an implementation (rather than the interface itself) needs fx.As to join the group as interfaceType
if !isExactType(interfaceType, rt) {
shouldAnnotate = true
}
effectiveNumOut++
}
}
isInputValid := true
// check if we can actually annotate
if shouldAnnotate {
for i := 0; i < t.NumIn(); i++ {
it := t.In(i)
// only struct parameters (not pointers to structs) are inspected
if it.Kind() == reflect.Struct {
for j := 0; j < it.NumField(); j++ {
// if the input struct embeds fx.In, then we won't be able to annotate, so it's invalid
// (this break only exits the field loop; isInputValid stays false)
if isExactType(it.Field(j).Type, typeFxIn) {
isInputValid = false
break
}
}
}
}
}
// on any violation, reset the other results so callers cannot accidentally use them
if !isOutputValid {
shouldAnnotate = false
effectiveNumOut = 0
err = fmt.Errorf("Web registable provider must return implementation of type %s.%s, but got %v",
interfaceType.PkgPath(), interfaceType.Name(), describeFunc(target))
} else if !isInputValid {
shouldAnnotate = false
effectiveNumOut = 0
err = fmt.Errorf("If web registable provider does not return exact type %s.%s, it must not use Fx.In, but got %v",
interfaceType.PkgPath(), interfaceType.Name(), describeFunc(target))
} else {
err = nil
}
return
}
// describeFunc returns the fully-qualified runtime name of the function f (e.g. "strings.ToUpper"),
// or "unknown function" when the runtime cannot resolve it. f must be a func value.
func describeFunc(f interface{}) string {
	fn := runtime.FuncForPC(reflect.ValueOf(f).Pointer())
	if fn != nil {
		return fn.Name()
	}
	return "unknown function"
}
// isExactType reports whether t is the same type as expected.
// Identical reflect.Type values always match. Otherwise two *named* types match when both package path and
// name are equal. Unnamed types (pointers, slices, maps, func literals, ...) all have an empty name and package
// path, so comparing them by name would make every unnamed type "equal". They only match when identical.
func isExactType(expected reflect.Type, t reflect.Type) bool {
	if t == expected {
		return true
	}
	if t == nil || expected == nil || t.Name() == "" || expected.Name() == "" {
		return false
	}
	return t.PkgPath() == expected.PkgPath() && t.Name() == expected.Name()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/gin-gonic/gin"
"net/http"
"path"
)
// contextPathCtxKey is the context key under which PropertiesAware stores the normalized context-path.
// ContextPath reads it back.
type contextPathCtxKey struct {}

// The interfaces, functions, HandlerFunc wrappers and gin middlewares below make sure *gin.Context is available
// in endpoints and that the context is properly propagated to the http.Request.

// SimpleGinMapping is a SimpleMapping that is served by a gin.HandlerFunc.
// See mapping.MappingBuilder
type SimpleGinMapping interface {
SimpleMapping
// GinHandlerFunc returns the handler serving this mapping (the implementation in this file may return nil).
GinHandlerFunc() gin.HandlerFunc
}
// MiddlewareGinMapping is a MiddlewareMapping that is served by a gin.HandlerFunc.
// See middleware.MappingBuilder
type MiddlewareGinMapping interface {
MiddlewareMapping
// GinHandlerFunc returns the middleware handler (the implementation in this file may return nil).
GinHandlerFunc() gin.HandlerFunc
}
/**************************
Public
**************************/
// GinContext returns the *gin.Context associated with ctx: either ctx itself, or the value stored under
// gin.ContextKey. It returns nil when neither is present.
func GinContext(ctx context.Context) *gin.Context {
	switch v := ctx.(type) {
	case *gin.Context:
		return v
	default:
		gc, _ := ctx.Value(gin.ContextKey).(*gin.Context)
		return gc
	}
}

// MustGinContext is like GinContext but panics when no *gin.Context can be found.
func MustGinContext(ctx context.Context) *gin.Context {
	gc := GinContext(ctx)
	if gc == nil {
		panic(fmt.Sprintf("gin.Context is not found in given context %v", ctx))
	}
	return gc
}
// HttpRequest returns the *http.Request of the *gin.Context associated with ctx, or nil if there is none.
func HttpRequest(ctx context.Context) *http.Request {
	gc := GinContext(ctx)
	if gc == nil {
		return nil
	}
	return gc.Request
}

// MustHttpRequest is like HttpRequest but panics when no *gin.Context can be found.
func MustHttpRequest(ctx context.Context) *http.Request {
	gc := MustGinContext(ctx)
	return gc.Request
}

// ContextPath returns the "server.context-path" property with a leading "/", as saved by PropertiesAware.
// It returns an empty string when the context-path is root or not set.
func ContextPath(ctx context.Context) string {
	if p, ok := ctx.Value(contextPathCtxKey{}).(string); ok {
		return p
	}
	return ""
}
// SetKV sets a key-value pair on the given context, so that it can later be obtained via context.Context.Value(key).
// The pair is stored only when at least one of these holds:
//   - the context is a utils.MutableContext, or has one as parent/ancestor
//   - the context contains a *gin.Context
//
// Storage is located with utils.FindMutableContext and GinContext, and EVERY storage found is written:
//   - If utils.FindMutableContext returns non-nil, the pair is stored via utils.MutableContext.Set as-is.
//   - If a *gin.Context is found, the pair is also stored on it: a string key is used as-is,
//     any other key is converted with fmt.Sprintf(`%v`, key).
//   - If neither is found, this function does nothing.
func SetKV(ctx context.Context, key interface{}, value interface{}) {
	if mc := utils.FindMutableContext(ctx); mc != nil {
		mc.Set(key, value)
	}
	gc := GinContext(ctx)
	if gc == nil {
		return
	}
	strKey, ok := key.(string)
	if !ok {
		strKey = fmt.Sprintf("%v", key)
	}
	gc.Set(strKey, value)
}
/**************************
Customizers
**************************/
// PriorityGinContextCustomizer implements Customizer and order.PriorityOrdered.
// It installs the mandatory global middlewares GinContextMerger and PropertiesAware, and enables
// Engine.ContextWithFallback.
type PriorityGinContextCustomizer struct {
	properties *ServerProperties
}

// NewPriorityGinContextCustomizer creates a PriorityGinContextCustomizer using the given server properties.
func NewPriorityGinContextCustomizer(properties *ServerProperties) *PriorityGinContextCustomizer {
	c := PriorityGinContextCustomizer{properties: properties}
	return &c
}

// PriorityOrder returns 0: medium precedence, which places this customizer before any
// non-priority-ordered customizers.
func (c PriorityGinContextCustomizer) PriorityOrder() int {
	return 0
}

// Customize registers GinContextMerger, then PropertiesAware, as global middlewares (stopping at the first error),
// and finally turns on Engine.ContextWithFallback.
//
//nolint:contextcheck // context is not relevant here - should pass the context parameter
func (c PriorityGinContextCustomizer) Customize(_ context.Context, r *Registrar) error {
	if err := r.AddGlobalMiddlewares(GinContextMerger()); err != nil {
		return err
	}
	if err := r.AddGlobalMiddlewares(PropertiesAware(c.properties)); err != nil {
		return err
	}
	enableFallback := func(engine *Engine) {
		engine.ContextWithFallback = true
	}
	return r.AddEngineOptions(enableFallback)
}
/**************************
Handler Func
**************************/
// PropertiesAware is a Gin middleware that is mandatory for all mappings.
// It saves the necessary server properties into the request's mutable context, e.g. the context-path.
// The saved properties are used by components/utilities such as ContextPath.
// A root context-path ("/") is not saved.
func PropertiesAware(props *ServerProperties) gin.HandlerFunc {
	return func(gc *gin.Context) {
		mc := utils.FindMutableContext(gc)
		if mc == nil {
			return
		}
		switch ctxPath := path.Clean("/" + props.ContextPath); ctxPath {
		case "/", ".":
			// root context-path: nothing to record
		default:
			mc.Set(contextPathCtxKey{}, ctxPath)
		}
	}
}
// GinContextMerger is a Gin middleware that merges Request.Context() with gin.Context, making values stored in
// gin.Context available via Request.Context().Value(). This middleware is mandatory for all mappings.
//
// Note: as of Gin 1.8.0, setting gin.Engine.ContextWithFallback to true fully integrates gin.Context with its
// underlying Request.Context(). As a side effect, gin.Context.Value() also calls Request.Context().Value(),
// which would overflow the stack on non-existing keys.
// To break this loop, values are extracted from gin.Context with ginContextValuer, which does not call
// gin.Context.Value().
func GinContextMerger() gin.HandlerFunc {
	return func(gc *gin.Context) {
		req := gc.Request
		merged := utils.MakeMutableContext(req.Context(), ginContextValuer(gc))
		// optional since Gin 1.8.0; done for performance
		merged.Set(gin.ContextKey, gc)
		gc.Request = req.WithContext(merged)
	}
}
// NewHttpGinHandlerFunc adapts a http.HandlerFunc into a gin.HandlerFunc.
// Before each invocation, the request's context is made mutable (see preProcessGinContext).
// It panics if handlerFunc is nil.
func NewHttpGinHandlerFunc(handlerFunc http.HandlerFunc) gin.HandlerFunc {
	if handlerFunc == nil {
		panic(fmt.Errorf("cannot wrap a nil hanlder"))
	}
	return func(c *gin.Context) {
		gc := preProcessGinContext(c)
		handlerFunc(gc.Writer, gc.Request)
	}
}

// preProcessGinContext makes sure the request context of gc is a utils.MutableContext and returns gc.
// GinContextMerger is a mandatory middleware for all mappings, so gc.Request.Context() already contains
// gin.Context. Only mutability needs to be ensured here.
func preProcessGinContext(gc *gin.Context) *gin.Context {
	original := gc.Request.Context()
	if mutable := utils.MakeMutableContext(original); mutable != original {
		gc.Request = gc.Request.WithContext(mutable)
	}
	// the gin context could be copied for use outside the request scope, but that is not required currently
	return gc
}
/**************************
helpers
**************************/
// ginContextValuer returns a value lookup function backed by gc that does not call gin.Context.Value(),
// avoiding the recursion described on GinContextMerger.
//   - gin.ContextKey returns gc itself
//   - any other string key is looked up with gc.Get
//   - non-string keys return nil. Values are stored on gc with string keys in this file (see SetKV), and
//     coercing a non-string key to "" would wrongly match a value stored under the empty-string key.
func ginContextValuer(gc *gin.Context) func(key interface{}) interface{} {
	return func(key interface{}) interface{} {
		strKey, ok := key.(string)
		if !ok {
			return nil
		}
		if strKey == gin.ContextKey {
			return gc
		}
		v, _ := gc.Get(strKey)
		return v
	}
}
/*********************************
SimpleGinMapping
*********************************/
// simpleGinMapping implements SimpleGinMapping on top of simpleMapping, with an optional gin.HandlerFunc.
type simpleGinMapping struct {
	simpleMapping
	handlerFunc gin.HandlerFunc
}

// NewSimpleGinMapping creates a SimpleGinMapping.
// It's recommended to use mapping.MappingBuilder instead of this function, e.g.
// <code>
// mapping.Post("/path/to/api").HandlerFunc(func...).Build()
// </code>
func NewSimpleGinMapping(name, group, path, method string, condition RequestMatcher, handlerFunc gin.HandlerFunc) SimpleGinMapping {
	base := NewSimpleMapping(name, group, path, method, condition, nil).(*simpleMapping)
	return &simpleGinMapping{
		simpleMapping: *base,
		handlerFunc:   handlerFunc,
	}
}

// GinHandlerFunc returns the gin handler if set. Otherwise it wraps the embedded mapping's http.HandlerFunc
// when present, and returns nil if neither exists.
func (m simpleGinMapping) GinHandlerFunc() gin.HandlerFunc {
	switch {
	case m.handlerFunc != nil:
		return m.handlerFunc
	case m.simpleMapping.handlerFunc != nil:
		return NewHttpGinHandlerFunc(m.simpleMapping.handlerFunc)
	default:
		return nil
	}
}
/*********************************
MiddlewareGinMapping
*********************************/
// middlewareGinMapping implements MiddlewareGinMapping on top of middlewareMapping, with an optional gin.HandlerFunc.
type middlewareGinMapping struct {
	middlewareMapping
	handlerFunc gin.HandlerFunc
}

// NewMiddlewareGinMapping creates a MiddlewareGinMapping with a gin.HandlerFunc.
// It's recommended to use middleware.MappingBuilder instead of this function, e.g.
// <code>
// middleware.NewBuilder("my-auth").Order(-10).Use(func...).Build()
// </code>
func NewMiddlewareGinMapping(name string, order int, matcher RouteMatcher, cond RequestMatcher, handlerFunc gin.HandlerFunc) MiddlewareGinMapping {
	base := NewMiddlewareMapping(name, order, matcher, cond, nil).(*middlewareMapping)
	return &middlewareGinMapping{
		middlewareMapping: *base,
		handlerFunc:       handlerFunc,
	}
}

// GinHandlerFunc returns the gin handler if set. Otherwise it wraps the embedded mapping's http.HandlerFunc
// when present, and returns nil if neither exists.
func (m middlewareGinMapping) GinHandlerFunc() gin.HandlerFunc {
	switch {
	case m.handlerFunc != nil:
		return m.handlerFunc
	case m.middlewareMapping.handlerFunc != nil:
		return NewHttpGinHandlerFunc(m.middlewareMapping.handlerFunc)
	default:
		return nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"context"
"encoding"
"encoding/json"
"github.com/gin-gonic/gin"
"net/http"
)
/**************************
Support GenHandling
**************************/
// GinErrorHandlingCustomizer implements Customizer by installing DefaultErrorHandling as a global middleware.
type GinErrorHandlingCustomizer struct{}

// NewGinErrorHandlingCustomizer creates a GinErrorHandlingCustomizer.
func NewGinErrorHandlingCustomizer() *GinErrorHandlingCustomizer {
	return new(GinErrorHandlingCustomizer)
}

// Customize registers DefaultErrorHandling as a global middleware.
func (c GinErrorHandlingCustomizer) Customize(ctx context.Context, r *Registrar) error {
	handler := DefaultErrorHandling()
	return r.AddGlobalMiddlewares(handler)
}
// DefaultErrorHandling implements last-resort error handling, for errors that were recorded on gin.Context
// but not handled downstream. After the remaining handlers run, it does nothing if a response was already
// written or no error was recorded. Otherwise it writes a single error response via handleError.
func DefaultErrorHandling() gin.HandlerFunc {
	return func(gc *gin.Context) {
		gc.Next()

		if gc.Writer.Written() || len(gc.Errors) == 0 {
			return
		}

		// Find the first error that implements StatusCoder (it carries an explicit HTTP status).
		// If none is found, use the first recorded error.
		err := gc.Errors[0].Err
		for _, e := range gc.Errors {
			//nolint:errorlint
			if _, ok := e.Err.(StatusCoder); ok {
				err = e.Err
				break
			}
		}
		handleError(gc.Request.Context(), err, gc.Writer)
	}
}
// handleError writes err as a HTTP response.
//   - body: MarshalJSON output if err is a json.Marshaler (Content-Type becomes JSON). If the body is still
//     empty, MarshalText output if err is an encoding.TextMarshaler. If still empty, err.Error().
//   - headers: Content-Type plus any headers exposed via Headerer.
//   - status: StatusCoder.StatusCode() if implemented, otherwise 500.
//
//nolint:errorlint
func handleError(_ context.Context, err error, rw http.ResponseWriter) {
	contentType, body := "text/plain; charset=utf-8", []byte{}

	// JSON first (body is always empty at this point)
	if jm, ok := err.(json.Marshaler); ok {
		if jsonBody, e := jm.MarshalJSON(); e == nil {
			contentType, body = "application/json; charset=utf-8", jsonBody
		}
	}
	// then text, only if nothing was produced so far
	if tm, ok := err.(encoding.TextMarshaler); ok && len(body) == 0 {
		if textBody, e := tm.MarshalText(); e == nil {
			body = textBody
		}
	}
	if len(body) == 0 {
		body = []byte(err.Error())
	}

	rw.Header().Set("Content-Type", contentType)
	if hd, ok := err.(Headerer); ok {
		for name, vals := range hd.Headers() {
			for _, val := range vals {
				rw.Header().Add(name, val)
			}
		}
	}

	status := http.StatusInternalServerError
	if sc, ok := err.(StatusCoder); ok {
		status = sc.StatusCode()
	}
	rw.WriteHeader(status)
	_, _ = rw.Write(body)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
)
// ginRequestRewriter implements RequestRewriter by handing the request's *gin.Context back to the gin engine.
type ginRequestRewriter struct {
	engine *gin.Engine
}

// newGinRequestRewriter creates a RequestRewriter bound to the given gin engine.
func newGinRequestRewriter(engine *gin.Engine) RequestRewriter {
	rw := ginRequestRewriter{engine: engine}
	return &rw
}

// HandleRewrite passes the *gin.Context linked to r to gin.Engine.HandleContext, so a modified request
// (e.g. with a rewritten URL) is handled again.
// Caution: a handler that rewrites unconditionally will loop forever.
// It returns an error if r is not linked to a *gin.Context.
func (rw ginRequestRewriter) HandleRewrite(r *http.Request) error {
	if gc := GinContext(r.Context()); gc != nil {
		rw.engine.HandleContext(gc)
		return nil
	}
	return fmt.Errorf("the request is not linked to a gin Context. Please make sure this is the right RequestRewriter to use")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"github.com/gin-gonic/gin"
"io/fs"
"net/http"
"strings"
)
var (
// predefinedAliases maps a file-path suffix to its replacement. Used by FilenameRewriteHandlerFunc (after the
// handler's own aliases) when the requested file cannot be read, e.g. ".../index" -> ".../index.html".
predefinedAliases = map[string]string {
"index": "index.html",
}
// predefinedExtensions are appended to an unreadable requested file path by tryExtensions, e.g. "app.js" -> "app.js.gz".
predefinedExtensions = []string{
".gz",
}
// gzipContentTypeMapping maps a pre-compressed asset suffix to the Content-Type set by PreCompressedGzipAsset.
gzipContentTypeMapping = map[string]string {
".js.gz": "text/javascript",
".css.gz": "text/css",
".html.gz": "text/html",
}
)
// ginStaticAssetsHandler provides gin middlewares for serving static assets from fsys.
// Requests for unreadable files are rewritten via rewriter using aliases and predefined extensions.
type ginStaticAssetsHandler struct {
	rewriter RequestRewriter
	fsys     fs.FS
	aliases  map[string]string
}

// FilenameRewriteHandlerFunc returns a middleware that leaves the request alone if the "filepath" param is
// readable. Otherwise it tries, in order, the handler's aliases, the predefined aliases, and finally the
// predefined extensions, stopping at the first rewrite performed.
func (h ginStaticAssetsHandler) FilenameRewriteHandlerFunc() gin.HandlerFunc {
	return func(gc *gin.Context) {
		file := gc.Param("filepath")
		switch {
		case h.canRead(file):
			// file exists as requested: nothing to rewrite
		case h.tryAliases(gc, h.aliases, file):
		case h.tryAliases(gc, predefinedAliases, file):
		default:
			h.tryExtensions(gc, predefinedExtensions, file)
		}
	}
}
// PreCompressedGzipAsset returns a middleware that, for gzip assets (see isGzipAsset), sets
// "Content-Encoding: gzip" and "Vary: Accept-Encoding". For recognized suffixes it also sets an explicit
// Content-Type, which some browsers (e.g. Firefox) require.
func (h ginStaticAssetsHandler) PreCompressedGzipAsset() gin.HandlerFunc {
	return func(gc *gin.Context) {
		if h.isGzipAsset(gc.Request) {
			gc.Header("Content-Encoding", "gzip")
			gc.Header("Vary", "Accept-Encoding")
			for suffix, contentType := range gzipContentTypeMapping {
				if strings.HasSuffix(gc.Request.URL.Path, suffix) {
					gc.Header("Content-Type", contentType)
					break
				}
			}
		}
	}
}

// isGzipAsset reports whether req targets a ".gz" path, accepts gzip encoding, and is neither a connection
// upgrade nor an event stream.
func (h ginStaticAssetsHandler) isGzipAsset(req *http.Request) bool {
	if !strings.HasSuffix(req.URL.Path, ".gz") {
		return false
	}
	acceptsGzip := strings.Contains(req.Header.Get("Accept-Encoding"), "gzip")
	isUpgrade := strings.Contains(req.Header.Get("Connection"), "Upgrade")
	isEventStream := strings.Contains(req.Header.Get("Content-Type"), "text/event-stream")
	return acceptsGzip && !isUpgrade && !isEventStream
}
// canRead reports whether filePath can be opened from h.fsys. Any file handle obtained is closed immediately.
func (h ginStaticAssetsHandler) canRead(filePath string) bool {
	f, e := h.fsys.Open(filePath)
	if f != nil {
		_ = f.Close()
	}
	return e == nil
}
// tryAliases rewrites the request for the first alias whose suffix matches file and whose aliased path differs
// from file and is readable. It reports whether a rewrite was performed. Map iteration order is unspecified,
// so when several aliases qualify, any one of them may win.
func (h ginStaticAssetsHandler) tryAliases(gc *gin.Context, aliases map[string]string, file string) bool {
	for suffix, replacement := range aliases {
		if !strings.HasSuffix(file, suffix) {
			continue
		}
		candidate := h.replaceLast(file, suffix, replacement)
		// avoid infinite loops / useless rewrites: the alias must differ and exist
		if candidate != file && h.canRead(candidate) {
			_ = h.rewrite(gc, suffix, replacement)
			return true
		}
	}
	return false
}

// tryExtensions rewrites the request to file+ext for the first extension whose resulting path differs from file
// and is readable. It reports whether a rewrite was performed.
func (h ginStaticAssetsHandler) tryExtensions(gc *gin.Context, extensions []string, file string) bool {
	for _, ext := range extensions {
		candidate := file + ext
		// avoid infinite loops / useless rewrites: the candidate must differ and exist
		if candidate == file || !h.canRead(candidate) {
			continue
		}
		_ = h.rewrite(gc, "", ext)
		return true
	}
	return false
}
// rewrite replaces the last occurrence of value in the request URL path with rewrite (or appends rewrite when
// value is empty), then hands the request to h.rewriter. The URL is copied, but the *http.Request itself is
// updated in place so that the rewriter sees the new URL.
func (h ginStaticAssetsHandler) rewrite(gc *gin.Context, value, rewrite string) error {
	req := gc.Request
	rewritten := *req.URL
	rewritten.Path = h.replaceLast(rewritten.Path, value, rewrite)
	req.URL = &rewritten
	return h.rewriter.HandleRewrite(req)
}

// replaceLast replaces the last occurrence of substr in s with replacement.
// An empty substr appends replacement; if substr is not found, s is returned unchanged.
func (h ginStaticAssetsHandler) replaceLast(s, substr, replacement string) string {
	if substr == "" {
		return s + replacement
	}
	idx := strings.LastIndex(s, substr)
	if idx == -1 {
		return s
	}
	head, tail := s[:idx], s[idx+len(substr):]
	return head + replacement + tail
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package init
import (
"context"
"embed"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/cors"
webtracing "github.com/cisco-open/go-lanai/pkg/web/tracing"
"go.uber.org/fx"
)
// defaultConfigFS embeds defaults-web.yml, which Module registers as embedded default configuration.
//go:embed defaults-web.yml
var defaultConfigFS embed.FS
// Module is the bootstrap module of the web package.
// Its priority options register the embedded defaults, provide the server properties, engine and registrar,
// and invoke setup. The cors and web tracing modules are included as sub-modules.
var Module = &bootstrap.Module{
Name: "web",
Precedence: web.MinWebPrecedence,
PriorityOptions: []fx.Option{
appconfig.FxEmbeddedDefaults(defaultConfigFS),
fx.Provide(
web.BindServerProperties,
web.NewEngine,
web.NewRegistrar),
fx.Invoke(setup),
},
Modules: []*bootstrap.Module{
cors.Module, webtracing.Module,
},
}
// Use registers Module with bootstrap, so that the web module is included in the application.
// Call it from main().
func Use() {
	bootstrap.Register(Module)
}

/**************************
	Provide dependencies
***************************/

/**************************
	Setup
***************************/

// initDI holds the dependencies injected by fx into setup.
// Controllers, Customizers and ErrorTranslators are collected from their respective fx value groups.
type initDI struct {
	fx.In
	Registrar        *web.Registrar
	Properties       web.ServerProperties
	Controllers      []web.Controller      `group:"controllers"`
	Customizers      []web.Customizer      `group:"customizers"`
	ErrorTranslators []web.ErrorTranslator `group:"error_translators"`
}

// setup registers the built-in customizers (logging, recovery, gin error handling), then all injected
// controllers, customizers and error translators, and ties the registrar's Run/Stop to the fx lifecycle.
func setup(lc fx.Lifecycle, di initDI) {
	reg := di.Registrar
	reg.MustRegister(web.NewLoggingCustomizer(di.Properties))
	reg.MustRegister(web.NewRecoveryCustomizer())
	reg.MustRegister(web.NewGinErrorHandlingCustomizer())
	reg.MustRegister(di.Controllers)
	reg.MustRegister(di.Customizers)
	reg.MustRegister(di.ErrorTranslators)

	start := func(ctx context.Context) error {
		return reg.Run(ctx)
	}
	stop := func(ctx context.Context) error {
		return reg.Stop(ctx)
	}
	lc.Append(fx.Hook{OnStart: start, OnStop: stop})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package mvc
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/web"
"net/http"
"reflect"
)
/*****************************
Func Metadata
******************************/
// Messages describing invalid MVC handler function signatures.
const (
	templateInvalidMvcHandlerFunc = "invalid MVC handler function signature: %v, but got <%v>"
	errorMsgExpectFunc            = "expecting a function"
	errorMsgInputParams           = "function should have one or two input parameters, where the first is context.Context and the second is a struct or pointer to struct"
	errorMsgOutputParams          = "function should have at least two output parameters, where the the last is error"
	errorMsgInvalidSignature      = "unable to find request or response type"
)

// mapping related

// errorInvalidMvcHandlerFunc reports an MVC handler function with an unsupported signature.
// Both reason and target must be non-nil when Error() is called.
type errorInvalidMvcHandlerFunc struct {
	reason error
	target *reflect.Value
}

// Error formats the reason together with the offending function's type.
func (e *errorInvalidMvcHandlerFunc) Error() string {
	reason, got := e.reason.Error(), e.target.Type()
	return fmt.Sprintf(templateInvalidMvcHandlerFunc, reason, got)
}
// Well-known types used to classify MVC handler function parameters and results.
var (
// input parameters convertible to context.Context are treated as the context parameter (see NewFuncMetadata)
specialTypeContext = reflect.TypeOf((*context.Context)(nil)).Elem()
// NOTE(review): not referenced in the visible part of NewFuncMetadata — presumably used further down
specialTypeHttpRequestPtr = reflect.TypeOf(&http.Request{})
// results convertible to int are treated as the status code
specialTypeInt = reflect.TypeOf(0)
// results convertible to http.Header are treated as response headers
specialTypeHttpHeader = reflect.TypeOf((*http.Header)(nil)).Elem()
// NOTE(review): not referenced in the visible part of NewFuncMetadata — presumably matches the error result
specialTypeError = reflect.TypeOf((*error)(nil)).Elem()
)
// HandlerFuncValidator validates a HandlerFunc signature, on top of the default validation done by NewFuncMetadata.
type HandlerFuncValidator func(f *reflect.Value) error
// HandlerFunc is a function with a supported signature that handles a MVC request and returns a MVC response or error.
// See rest.MappingBuilder and template.MappingBuilder for supported function signatures
type HandlerFunc interface{}
// param locates one parameter of a handler function signature: its position and its type.
// A negative position or a nil type marks the parameter as absent.
type param struct {
	i int
	t reflect.Type
}

// isValid reports whether the parameter is present in the signature.
func (p param) isValid() bool {
	hasIndex := p.i >= 0
	hasType := p.t != nil
	return hasIndex && hasType
}
// mvcOut records where each recognized output parameter sits among the handler's return values.
// Absent outputs keep an invalid param (index -1).
type mvcOut struct {
	// count is the total number of output parameters
	count int
	// sc is the status-code output (any type convertible to int)
	sc param
	// header is the response-header output (any type convertible to http.Header)
	header param
	// response is the response body output
	response param
	// error is the error output (any type convertible to error)
	error param
}

// mvcIn records where each recognized input parameter sits in the handler's parameter list.
// Absent inputs keep an invalid param (index -1).
type mvcIn struct {
	// count is the total number of input parameters
	count int
	// context is the context.Context input
	context param
	// request is the request input (struct or pointer to struct)
	request param
}

// Metadata is the result of analyzing an MVC handler function with NewFuncMetadata.
type Metadata struct {
	// function is the reflected handler function
	function *reflect.Value
	// request is the handler's request type; nil when the handler has no request parameter
	request reflect.Type
	// response is the handler's response type
	response reflect.Type
	in       mvcIn
	out      mvcOut
}
// NewFuncMetadata uses reflection to analyze the given handler function and create a Metadata.
// Caller can provide an optional validator to further validate the signature on top of the default validation.
//
// Input parameters are classified as the context (convertible to context.Context) or the request
// (struct or pointer to struct). Output parameters are classified as status code (convertible to int),
// header (convertible to http.Header), error (convertible to error) or response (struct, pointer to struct,
// interface, map, string, or byte slice/array).
//
// This function panics with *errorInvalidMvcHandlerFunc if the function has an unsupported signature.
func NewFuncMetadata(endpointFunc HandlerFunc, validator HandlerFuncValidator) *Metadata {
	f := reflect.ValueOf(endpointFunc)
	if err := validateFunc(&f, validator); err != nil {
		// fatal error: handlers are registered at startup, an invalid one is a programming error
		panic(err)
	}

	t := f.Type()
	unknown := param{-1, nil}
	meta := Metadata{
		function: &f,
		in: mvcIn{
			context: unknown, request: unknown,
		},
		out: mvcOut{
			sc: unknown, header: unknown,
			response: unknown, error: unknown,
		},
	}

	// parse input params
	for i := t.NumIn() - 1; i >= 0; i-- {
		switch it := t.In(i); {
		case it.ConvertibleTo(specialTypeContext):
			meta.in.context = param{i, it}
		case !meta.in.request.isValid() && isSupportedRequestType(it):
			meta.in.request = param{i, it}
			meta.request = it
		default:
			panic(&errorInvalidMvcHandlerFunc{
				reason: fmt.Errorf("unknown input parameters at index %v", i),
				target: &f,
			})
		}
		meta.in.count++
	}

	// parse output params
	for i := t.NumOut() - 1; i >= 0; i-- {
		switch ot := t.Out(i); {
		case ot.ConvertibleTo(specialTypeInt):
			meta.out.sc = param{i, ot}
		case ot.ConvertibleTo(specialTypeHttpHeader):
			meta.out.header = param{i, ot}
		case ot.ConvertibleTo(specialTypeError):
			meta.out.error = param{i, ot}
		case !meta.out.response.isValid() && isSupportedResponseType(ot):
			// we allow interface and map as response
			meta.out.response = param{i, ot}
			meta.response = ot
		default:
			panic(&errorInvalidMvcHandlerFunc{
				reason: fmt.Errorf("unknown return parameters at index %v", i),
				target: &f,
			})
		}
		meta.out.count++
	}

	// A usable handler needs a response output, at least one input and two outputs,
	// and a recognized request type whenever it has more than one input.
	if meta.response == nil || meta.in.count < 1 || meta.out.count < 2 || meta.in.count > 1 && meta.request == nil {
		panic(&errorInvalidMvcHandlerFunc{
			reason: errors.New(errorMsgInvalidSignature),
			target: &f,
		})
	}
	return &meta
}
// HandlerFunc adapts the analyzed handler function into a web.MvcHandlerFunc.
// The returned function calls the handler via reflection with the context and the decoded request.
// The handler's error output becomes err and its response output becomes response. When the handler
// also returns a status code and/or header, the response is wrapped in a *web.Response carrying them.
func (m Metadata) HandlerFunc() web.MvcHandlerFunc {
	return func(c context.Context, request interface{}) (response interface{}, err error) {
		// prepare input params
		in := make([]reflect.Value, m.in.count)
		in[m.in.context.i] = reflect.ValueOf(c)
		if m.in.request.isValid() {
			rv := reflect.ValueOf(request)
			if !rv.IsValid() {
				// a nil request yields the zero reflect.Value, which makes Call panic;
				// pass the typed zero value of the declared request type instead
				rv = reflect.Zero(m.in.request.t)
			}
			in[m.in.request.i] = rv
		}
		out := m.function.Call(in)

		// post process output
		errOut := out[m.out.error.i]
		nilErr := false
		switch errOut.Kind() {
		case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan:
			// a typed nil (e.g. a nil *MyError) must not turn into a non-nil error interface
			nilErr = errOut.IsNil()
		}
		if !nilErr {
			err, _ = errOut.Interface().(error)
		}
		response = out[m.out.response.i].Interface()
		if !m.out.sc.isValid() && !m.out.header.isValid() {
			return response, err
		}

		// if necessary, wrap the response
		wrapper := &web.Response{B: response}
		if m.out.sc.isValid() {
			// the status code may be any type convertible to int (e.g. uint16);
			// Value.Int panics on non-signed-integer kinds, so convert first
			wrapper.SC = int(out[m.out.sc.i].Convert(specialTypeInt).Int())
		}
		if m.out.header.isValid() {
			wrapper.H, _ = out[m.out.header.i].Interface().(http.Header)
		}
		return wrapper, err
	}
}
// validateFunc checks that f is a function with a supported MVC handler signature:
//   - one or two input parameters: the first convertible to context.Context, the optional second
//     a struct or pointer to struct;
//   - at least two output parameters, the last convertible to error.
//
// Violations are reported as *errorInvalidMvcHandlerFunc. When the default checks pass and validator
// is non-nil, the validator's result is returned.
func validateFunc(f *reflect.Value, validator HandlerFuncValidator) (err error) {
	// Signatures can only be checked at runtime via reflection.
	// Check the kind before calling f.Type(): Type panics on the zero Value (e.g. a nil handler).
	if f.Kind() != reflect.Func {
		return &errorInvalidMvcHandlerFunc{
			reason: errors.New(errorMsgExpectFunc),
			target: f,
		}
	}

	t := f.Type()
	switch {
	// In params validation. Case expressions are evaluated left to right and stop at the first
	// match, so t.In(0) is only evaluated when at least one input parameter exists.
	case t.NumIn() < 1 || t.NumIn() > 2,
		!t.In(0).ConvertibleTo(specialTypeContext),
		t.NumIn() == 2 && !isSupportedRequestType(t.In(1)):
		return &errorInvalidMvcHandlerFunc{
			reason: errors.New(errorMsgInputParams),
			target: f,
		}
	// Out params validation; t.Out is only evaluated when there are at least two outputs.
	case t.NumOut() < 2,
		!t.Out(t.NumOut()-1).ConvertibleTo(specialTypeError):
		return &errorInvalidMvcHandlerFunc{
			reason: errors.New(errorMsgOutputParams),
			target: f,
		}
	}

	if validator != nil {
		return validator(f)
	}
	return nil
}
// isStructOrPtrToStruct reports whether t is a struct type or a pointer to a struct type.
func isStructOrPtrToStruct(t reflect.Type) (ret bool) {
	switch t.Kind() {
	case reflect.Struct:
		return true
	case reflect.Ptr:
		return t.Elem().Kind() == reflect.Struct
	default:
		return false
	}
}
// isHttpRequestPtr reports whether t is exactly *http.Request.
func isHttpRequestPtr(t reflect.Type) bool {
	return specialTypeHttpRequestPtr == t
}

// isSupportedRequestType reports whether t may be used as the request parameter of an MVC handler.
// Only structs and pointers to structs are accepted (*http.Request qualifies).
func isSupportedRequestType(t reflect.Type) bool {
	supported := isStructOrPtrToStruct(t)
	return supported
}
// isSupportedResponseType reports whether t may be used as the response output of an MVC handler:
// a struct, a pointer to struct, an interface, a map, a string, or a slice/array of bytes.
func isSupportedResponseType(t reflect.Type) bool {
	switch t.Kind() {
	case reflect.Struct, reflect.Interface, reflect.Map, reflect.String:
		return true
	case reflect.Ptr:
		return t.Elem().Kind() == reflect.Struct
	case reflect.Slice, reflect.Array:
		return t.Elem().Kind() == reflect.Uint8
	default:
		return false
	}
}
package mvc
import (
"context"
"github.com/cisco-open/go-lanai/pkg/web"
"net/http"
"reflect"
)
/**********************************
Request Decoder
***********************************/
// GinBindingRequestDecoder creates a web.DecodeRequestFunc for the handler described by s.
// When the handler has no request parameter, or takes *http.Request directly, the decoder simply
// passes the incoming *http.Request through. Otherwise decoding is delegated to
// web.GinBindingRequestDecoder with a factory that instantiates a value of type s.request.
func GinBindingRequestDecoder(s *Metadata) web.DecodeRequestFunc {
	if s.request != nil && !isHttpRequestPtr(s.request) {
		// decode request using gin binding
		// NOTE(review): the factory returns a reflect.Value boxed in interface{};
		// presumably web.GinBindingRequestDecoder unwraps it - confirm
		return web.GinBindingRequestDecoder(func() interface{} {
			return instantiateByType(s.request)
		})
	}
	// no decoding needed: hand the raw request to the handler
	return func(c context.Context, r *http.Request) (request interface{}, err error) {
		return r, nil
	}
}
// instantiateByType allocates a value of type t.
// For a pointer type, the result is a non-nil pointer to a newly allocated zero value of the element type.
// For any other type, the result is the (addressable) zero value of t.
func instantiateByType(t reflect.Type) reflect.Value {
	if t.Kind() != reflect.Ptr {
		return reflect.New(t).Elem()
	}
	return reflect.New(t.Elem())
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
"github.com/gin-gonic/gin"
"io"
"net/http"
"strconv"
"strings"
"time"
)
// logger is the "Web" package logger. LoggingCustomizer also routes gin's own debug/error writers to it.
var logger = log.New("Web")

// Keys of the structured "http" key-value entry attached to each HTTP access log record.
const (
	// LogKeyHttp is the key under which the map of HTTP attributes is attached
	LogKeyHttp          = "http"
	LogKeyHttpStatus    = "status"
	LogKeyHttpMethod    = "method"
	LogKeyHttpClientIP  = "clientIP"
	LogKeyHttpPath      = "path"
	LogKeyHttpErrorMsg  = "error"
	LogKeyHttpBodySize  = "bodySize"
)

// LoggingCustomizer implements Customizer and PostInitCustomizer.
// It installs the HTTP access-log middleware and redirects gin's debug/error output to the "Web" logger.
type LoggingCustomizer struct {
	// enabled controls whether the access-log middleware is installed
	enabled bool
	// defaultLvl is used for requests that match none of the configured levels
	defaultLvl log.LoggingLevel
	// levels maps request matchers to per-request logging levels
	levels map[RequestMatcher]log.LoggingLevel
}
// NewLoggingCustomizer creates a LoggingCustomizer configured from props.Logging.
// Per-request levels are built from props.Logging.Levels, with patterns prefixed by the context path.
func NewLoggingCustomizer(props ServerProperties) *LoggingCustomizer {
	c := &LoggingCustomizer{}
	c.enabled = props.Logging.Enabled
	c.defaultLvl = props.Logging.DefaultLevel
	c.levels = initLevelMap(&props)
	return c
}

// NewSimpleGinLogFormatter is a convenient function that returns a simple gin.LogFormatter without request filtering.
// Normally, LoggingCustomizer configures more complicated gin logging schema automatically.
// This function is provided purely for integrating with 3rd-party libraries that configure gin.Engine separately,
// e.g. KrakenD in API Gateway Service.
// The returned formatter logs directly through the given logger and always returns an empty string.
func NewSimpleGinLogFormatter(logger log.ContextualLogger, defaultLevel log.LoggingLevel, levels map[RequestMatcher]log.LoggingLevel) gin.LogFormatter {
	return logFormatter{
		logger:     logger,
		defaultLvl: defaultLevel,
		levels:     levels,
	}.intercept
}
// initLevelMap builds the per-request logging level overrides from props.Logging.Levels.
// Each entry's pattern is prefixed with props.ContextPath. An entry's Method is either empty or "*"
// (any method), or a whitespace-separated list of HTTP methods such as "GET POST".
func initLevelMap(props *ServerProperties) map[RequestMatcher]log.LoggingLevel {
	levels := make(map[RequestMatcher]log.LoggingLevel, len(props.Logging.Levels))
	for _, v := range props.Logging.Levels {
		pattern := props.ContextPath + v.Pattern
		var methods []string
		if spec := strings.TrimSpace(v.Method); spec != "" && spec != "*" {
			// Fields splits on any run of whitespace (spaces, tabs) and drops empty entries
			methods = strings.Fields(spec)
		}
		levels[withLoggingRequestPattern(pattern, methods...)] = v.Level
	}
	return levels
}
// Customize redirects gin's debug output (debug level) and error output (warn level) to the "Web"
// logger bound to ctx. When request logging is enabled, it also registers a global gin logging
// middleware whose entries are produced by logFormatter.
// An error is returned if the middleware cannot be registered.
func (c LoggingCustomizer) Customize(ctx context.Context, r *Registrar) error {
	// override gin debug
	gin.DefaultWriter = log.NewWriterAdapter(logger.WithContext(ctx), log.LevelDebug)
	gin.DefaultErrorWriter = log.NewWriterAdapter(logger.WithContext(ctx), log.LevelWarn)
	if !c.enabled {
		return nil
	}

	// setup logger middleware
	formatter := logFormatter{
		defaultLvl: c.defaultLvl,
		logger:     logger,
		levels:     c.levels,
	}
	mw := gin.LoggerWithConfig(gin.LoggerConfig{
		Formatter: formatter.intercept,
		Output:    io.Discard, // our logFormatter calls logger directly
	})
	if e := r.AddGlobalMiddlewares(mw); e != nil {
		// report through the declared error result instead of panicking
		return fmt.Errorf("unable to register HTTP logging middleware: %w", e)
	}
	return nil
}
// PostInit re-points gin's writers at the "Web" logger without the initialization context.
// The levels match those used during Customize: debug for gin's debug output, warn for its error output.
func (c LoggingCustomizer) PostInit(_ context.Context, _ *Registrar) error {
	// release initializing context
	gin.DefaultWriter = log.NewWriterAdapter(logger, log.LevelDebug)
	gin.DefaultErrorWriter = log.NewWriterAdapter(logger, log.LevelWarn)
	return nil
}
// logFormatter produces HTTP access-log entries from gin.LogFormatterParams.
type logFormatter struct {
	// logger receives the entries
	logger log.ContextualLogger
	// defaultLvl is used for requests that match none of the levels
	defaultLvl log.LoggingLevel
	// levels maps request matchers to logging levels; log.LevelOff suppresses the entry
	levels map[RequestMatcher]log.LoggingLevel
}
// intercept implements gin.LogFormatter by logging the request directly through f.logger and
// returning an empty string. Logging directly allows attaching structured key-value pairs
// (the LogKeyHttp entry and the gin context keys).
// Nothing is logged when the level resolved for the request is log.LevelOff.
func (f logFormatter) intercept(params gin.LogFormatterParams) (empty string) {
	logLevel := f.logLevel(params.Request)
	if logLevel == log.LevelOff {
		return
	}

	var statusColor, methodColor, resetColor string
	methodLen := 7
	if log.IsTerminal(f.logger) {
		statusColor = fixColor(params.StatusCodeColor())
		methodColor = fixColor(params.MethodColor())
		resetColor = params.ResetColor()
		// escape sequences take bytes but no screen columns; widen the padding accordingly
		methodLen = methodLen + len(methodColor) + len(resetColor)
	}

	if params.Latency > time.Minute {
		params.Latency = params.Latency.Truncate(time.Minute)
	}
	params.ErrorMessage = strings.Trim(params.ErrorMessage, "\n")

	// prepare message
	method := fmt.Sprintf("%-"+strconv.Itoa(methodLen)+"s", methodColor+" "+params.Method+" "+resetColor)
	msg := fmt.Sprintf("[HTTP] %s %3d %s | %10v | %8s | %s %#v %s",
		statusColor, params.StatusCode, resetColor,
		params.Latency.Truncate(time.Microsecond),
		formatSize(params.BodySize),
		method,
		params.Path,
		params.ErrorMessage)

	// prepare kv
	ctx := utils.MakeMutableContext(params.Request.Context())
	for k, v := range params.Keys {
		ctx.Set(k, v)
	}
	httpEntry := map[string]interface{}{
		LogKeyHttpStatus:   params.StatusCode,
		LogKeyHttpMethod:   params.Method,
		LogKeyHttpClientIP: params.ClientIP,
		LogKeyHttpPath:     params.Path,
		LogKeyHttpBodySize: params.BodySize,
		LogKeyHttpErrorMsg: params.ErrorMessage,
	}

	// do log
	// msg embeds request-controlled text (path, error message) that may contain '%';
	// pass it as an argument so it is never interpreted as a format string
	f.logger.WithContext(ctx).WithLevel(logLevel).WithKV(LogKeyHttp, httpEntry).Printf("%s", msg)
	return
}
// logLevel resolves the logging level for r: the level of the first configured matcher that matches
// without error, or f.defaultLvl when none does.
// Note: f.levels is a map and Go map iteration order is unspecified, so when several matchers
// match the same request, which level wins is not deterministic.
func (f logFormatter) logLevel(r *http.Request) log.LoggingLevel {
	lvl := f.defaultLvl
	for m, l := range f.levels {
		if ok, e := m.Matches(r); e != nil || !ok {
			continue
		}
		lvl = l
		break
	}
	return lvl
}
// Binary byte-size units used by formatSize.
const (
	kb = 1024
	mb = kb * kb
	gb = mb * kb
)

// formatSize renders a byte count for the access log: whole bytes below 1KB, otherwise the size
// in KB, MB or GB (powers of 1024) with two decimals.
func formatSize(n int) string {
	var (
		value float64
		unit  string
	)
	switch {
	case n < kb:
		return fmt.Sprintf("%dB", n)
	case n < mb:
		value, unit = float64(n)/kb, "KB"
	case n < gb:
		value, unit = float64(n)/mb, "MB"
	default:
		value, unit = float64(n)/gb, "GB"
	}
	return fmt.Sprintf("%.2f%s", value, unit)
}
// fixColor rewrites ANSI SGR sequences containing "43" (yellow background) so that the first
// "90;" foreground code (bright black) becomes "97;" (bright white), presumably to keep the text
// readable on a yellow background. Other sequences are returned unchanged.
func fixColor(color string) string {
	if !strings.Contains(color, "43") {
		return color
	}
	return strings.Replace(color, "90;", "97;", 1)
}
// loggingRequestMatcher implements RequestMatcher.
// It is used exclusively by logFormatter. The purpose of this matcher is to:
//  1. break a cyclic package dependency (on the web matcher package)
//  2. provide simple and fast matching
type loggingRequestMatcher struct {
	// pathMatcher is applied to the request URL path
	pathMatcher matcher.StringMatcher
	// methods lists accepted HTTP methods; empty means any method
	methods []string
	// description is returned by String
	description string
}
// withLoggingRequestPattern creates a loggingRequestMatcher for a path pattern and optional HTTP methods.
// When no methods are given, requests with any method match.
func withLoggingRequestPattern(pattern string, methods ...string) *loggingRequestMatcher {
	pathMatcher := matcher.WithPathPattern(pattern)
	desc := fmt.Sprintf("request matches %v %s", methods, pattern)
	return &loggingRequestMatcher{
		pathMatcher: pathMatcher,
		methods:     methods,
		description: desc,
	}
}
// RequestMatches reports whether the request path matches the pattern and, when methods are
// configured, whether the request method is one of them. Path matching errors are returned.
func (m *loggingRequestMatcher) RequestMatches(_ context.Context, r *http.Request) (bool, error) {
	if ok, e := m.pathMatcher.Matches(r.URL.Path); e != nil || !ok {
		return false, e
	}
	if len(m.methods) == 0 {
		return true, nil
	}
	for _, allowed := range m.methods {
		if allowed == r.Method {
			return true, nil
		}
	}
	return false, nil
}

// Matches is MatchesWithContext with context.TODO().
func (m *loggingRequestMatcher) Matches(i interface{}) (bool, error) {
	return m.MatchesWithContext(context.TODO(), i)
}

// MatchesWithContext accepts only *http.Request; any other type yields an error.
func (m *loggingRequestMatcher) MatchesWithContext(c context.Context, i interface{}) (bool, error) {
	req, ok := i.(*http.Request)
	if !ok {
		return false, fmt.Errorf("unsupported type %T", i)
	}
	return m.RequestMatches(c, req)
}

// Or combines this matcher with others using logical OR.
func (m *loggingRequestMatcher) Or(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.Or(m, matchers...)
}

// And combines this matcher with others using logical AND.
func (m *loggingRequestMatcher) And(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.And(m, matchers...)
}

// String returns the human-readable description of the matcher.
func (m *loggingRequestMatcher) String() string {
	return m.description
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package mapping
import (
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/gin-gonic/gin"
"net/http"
)
/*********************************
SimpleMappingBuilder
*********************************/
// MappingBuilder builds web.SimpleMapping.
// MappingBuilder.Path, MappingBuilder.Method and MappingBuilder.HandlerFunc are required to successfully build a mapping.
// Supported handler functions are gin.HandlerFunc or http.HandlerFunc.
// Example:
// <code>
// mapping.Post("/path/to/api").HandlerFunc(func...).Build()
// </code>
//goland:noinspection GoNameStartsWithPackageName
type MappingBuilder struct {
	// name of the mapping; generated from method, group and path at build time when empty
	name string
	// group is the path prefix the mapping belongs to
	group string
	// path is the mapping's path relative to group
	path string
	// method is the HTTP method; web.MethodAny by default
	method string
	// condition is an optional additional request matcher
	condition web.RequestMatcher
	// handlerFunc holds a gin.HandlerFunc or an http.HandlerFunc
	handlerFunc interface{}
}
// New creates a MappingBuilder whose method defaults to web.MethodAny.
// The optional first element of names becomes the mapping name; any further elements are ignored.
func New(names ...string) *MappingBuilder {
	b := &MappingBuilder{method: web.MethodAny}
	if len(names) != 0 {
		b.name = names[0]
	}
	return b
}

// Convenient Constructors

// Any creates a builder for path that accepts any HTTP method.
func Any(path string) *MappingBuilder {
	return New().Method(web.MethodAny).Path(path)
}

// Get creates a builder for GET requests on path.
func Get(path string) *MappingBuilder {
	b := New()
	return b.Get(path)
}

// Post creates a builder for POST requests on path.
func Post(path string) *MappingBuilder {
	b := New()
	return b.Post(path)
}

// Put creates a builder for PUT requests on path.
func Put(path string) *MappingBuilder {
	b := New()
	return b.Put(path)
}

// Patch creates a builder for PATCH requests on path.
func Patch(path string) *MappingBuilder {
	b := New()
	return b.Patch(path)
}

// Delete creates a builder for DELETE requests on path.
func Delete(path string) *MappingBuilder {
	b := New()
	return b.Delete(path)
}

// Options creates a builder for OPTIONS requests on path.
func Options(path string) *MappingBuilder {
	b := New()
	return b.Options(path)
}

// Head creates a builder for HEAD requests on path.
func Head(path string) *MappingBuilder {
	b := New()
	return b.Head(path)
}
/*****************************
Public
******************************/
// Name sets the mapping name and returns the builder for chaining.
func (b *MappingBuilder) Name(name string) *MappingBuilder {
	b.name = name
	return b
}

// Group sets the path group (prefix) of the mapping and returns the builder for chaining.
func (b *MappingBuilder) Group(group string) *MappingBuilder {
	b.group = group
	return b
}

// Path sets the mapping path, relative to the group, and returns the builder for chaining.
func (b *MappingBuilder) Path(path string) *MappingBuilder {
	b.path = path
	return b
}

// Method sets the HTTP method and returns the builder for chaining.
func (b *MappingBuilder) Method(method string) *MappingBuilder {
	b.method = method
	return b
}

// Condition sets an additional request matcher and returns the builder for chaining.
func (b *MappingBuilder) Condition(condition web.RequestMatcher) *MappingBuilder {
	b.condition = condition
	return b
}
// HandlerFunc sets the function handling matched requests. Supported:
//   - gin.HandlerFunc, or a plain func(*gin.Context)
//   - http.HandlerFunc, or a plain func(http.ResponseWriter, *http.Request)
//
// Plain function literals are converted to the corresponding named type, because a type switch
// on the named types would otherwise not recognize them. Any other value (including nil) panics.
func (b *MappingBuilder) HandlerFunc(handlerFunc interface{}) *MappingBuilder {
	switch fn := handlerFunc.(type) {
	case gin.HandlerFunc, http.HandlerFunc:
		b.handlerFunc = fn
	case func(*gin.Context):
		b.handlerFunc = gin.HandlerFunc(fn)
	case func(http.ResponseWriter, *http.Request):
		b.handlerFunc = http.HandlerFunc(fn)
	default:
		panic(fmt.Errorf("unsupported HandlerFunc type: %T", handlerFunc))
	}
	return b
}
// Convenient setters: each sets both the path and the corresponding HTTP method.

// Get configures the builder for GET requests on path.
func (b *MappingBuilder) Get(path string) *MappingBuilder {
	return b.Method(http.MethodGet).Path(path)
}

// Post configures the builder for POST requests on path.
func (b *MappingBuilder) Post(path string) *MappingBuilder {
	return b.Method(http.MethodPost).Path(path)
}

// Put configures the builder for PUT requests on path.
func (b *MappingBuilder) Put(path string) *MappingBuilder {
	return b.Method(http.MethodPut).Path(path)
}

// Patch configures the builder for PATCH requests on path.
func (b *MappingBuilder) Patch(path string) *MappingBuilder {
	return b.Method(http.MethodPatch).Path(path)
}

// Delete configures the builder for DELETE requests on path.
func (b *MappingBuilder) Delete(path string) *MappingBuilder {
	return b.Method(http.MethodDelete).Path(path)
}

// Options configures the builder for OPTIONS requests on path.
func (b *MappingBuilder) Options(path string) *MappingBuilder {
	return b.Method(http.MethodOptions).Path(path)
}

// Head configures the builder for HEAD requests on path.
func (b *MappingBuilder) Head(path string) *MappingBuilder {
	return b.Method(http.MethodHead).Path(path)
}
// Build validates the builder and creates the web.SimpleMapping.
// It panics if the path (with no meaningful group) or the handler function is missing.
func (b *MappingBuilder) Build() web.SimpleMapping {
	err := b.validate()
	if err != nil {
		panic(err)
	}
	return b.buildMapping()
}
/*****************************
Getters
******************************/
// GetPath returns the configured path.
func (b *MappingBuilder) GetPath() string {
	path := b.path
	return path
}

// GetMethod returns the configured HTTP method.
func (b *MappingBuilder) GetMethod() string {
	method := b.method
	return method
}

// GetCondition returns the configured additional request matcher, if any.
func (b *MappingBuilder) GetCondition() web.RequestMatcher {
	condition := b.condition
	return condition
}

// GetName returns the configured mapping name (possibly filled in by Build).
func (b *MappingBuilder) GetName() string {
	name := b.name
	return name
}
/*****************************
Private
******************************/
// validate checks the minimum required configuration: a path (unless a non-root group is set)
// and a handler function.
func (b *MappingBuilder) validate() (err error) {
	rootGroup := b.group == "" || b.group == "/"
	if b.path == "" && rootGroup {
		return errors.New("empty path")
	}
	if b.handlerFunc == nil {
		return errors.New("handler func not specified")
	}
	return nil
}

// buildMapping creates the mapping from the builder's state.
// Defaults are written back to the builder: an empty method becomes web.MethodAny, and an empty
// name becomes "<method> <group><path>".
func (b *MappingBuilder) buildMapping() web.SimpleMapping {
	if b.method == "" {
		b.method = web.MethodAny
	}
	if b.name == "" {
		b.name = fmt.Sprintf("%s %s%s", b.method, b.group, b.path)
	}

	if h, ok := b.handlerFunc.(gin.HandlerFunc); ok {
		return web.NewSimpleGinMapping(b.name, b.group, b.path, b.method, b.condition, h)
	}
	if h, ok := b.handlerFunc.(http.HandlerFunc); ok {
		return web.NewSimpleMapping(b.name, b.group, b.path, b.method, b.condition, h)
	}
	panic(fmt.Errorf("unsupported HandlerFunc type: %T", b.handlerFunc))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package matcher
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
"github.com/cisco-open/go-lanai/pkg/web"
"net/http"
"strings"
)
const (
	// descTmplPath formats the description of path-based matchers
	descTmplPath = `path %s`
)

// matchableFunc extracts the value to be matched (host, method, path, header, form value, ...) from a request.
type matchableFunc func(context.Context, *http.Request) (interface{}, error)

// requestMatcher implements web.RequestMatcher.
// It extracts a value from the request with matchableFunc and matches it with delegate.
// When matchableFunc is nil, the *http.Request itself is passed to delegate.
type requestMatcher struct {
	description   string
	matchableFunc matchableFunc
	delegate      matcher.Matcher
}
// RequestMatches extracts the matchable value from r (or uses r itself when no extractor is set)
// and delegates the match. Extraction errors are returned as-is.
func (m *requestMatcher) RequestMatches(c context.Context, r *http.Request) (bool, error) {
	var target interface{} = r
	if m.matchableFunc != nil {
		v, err := m.matchableFunc(c, r)
		if err != nil {
			return false, err
		}
		target = v
	}
	return m.delegate.MatchesWithContext(c, target)
}

// Matches is MatchesWithContext with context.TODO().
func (m *requestMatcher) Matches(i interface{}) (bool, error) {
	return m.MatchesWithContext(context.TODO(), i)
}

// MatchesWithContext accepts http.Request or *http.Request; other types yield an error.
func (m *requestMatcher) MatchesWithContext(c context.Context, i interface{}) (bool, error) {
	req, err := interfaceToRequest(i)
	if err != nil {
		return false, err
	}
	return m.RequestMatches(c, req)
}

// Or combines this matcher with others using logical OR.
func (m *requestMatcher) Or(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.Or(m, matchers...)
}

// And combines this matcher with others using logical AND.
func (m *requestMatcher) And(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.And(m, matchers...)
}

// String returns the description, falling back to the delegate's String and then a generic name.
func (m *requestMatcher) String() string {
	if len(m.description) != 0 {
		return m.description
	}
	if s, ok := m.delegate.(fmt.Stringer); ok {
		return s.String()
	}
	return "web.RequestMatcher"
}
/**************************
Constructors
***************************/
// AnyRequest matches every request.
func AnyRequest() web.RequestMatcher {
	any := matcher.Any()
	return wrapAsRequestMatcher(any)
}

// NoneRequest matches no request.
func NoneRequest() web.RequestMatcher {
	none := matcher.None()
	return wrapAsRequestMatcher(none)
}

// NotRequest negates the given request matcher.
func NotRequest(m web.RequestMatcher) web.RequestMatcher {
	negated := matcher.Not(m)
	return wrapAsRequestMatcher(negated)
}
// RequestWithHost matches requests whose Host equals expected (case-sensitive match flag = true).
// TODO support wildcard
func RequestWithHost(expected string) web.RequestMatcher {
	delegate := matcher.WithString(expected, true)
	desc := fmt.Sprintf("host %s", delegate.(fmt.Stringer).String())
	return &requestMatcher{
		description:   desc,
		matchableFunc: host,
		delegate:      delegate,
	}
}

// RequestWithMethods matches requests whose method is any of methods; with no methods, any request matches.
func RequestWithMethods(methods ...string) web.RequestMatcher {
	var delegate matcher.ChainableMatcher
	switch len(methods) {
	case 0:
		delegate = matcher.Any()
	default:
		delegate = matcher.WithString(methods[0], true)
		for _, other := range methods[1:] {
			delegate = delegate.Or(matcher.WithString(other, true))
		}
	}
	return &requestMatcher{
		description:   fmt.Sprintf("method %v", delegate),
		matchableFunc: method,
		delegate:      delegate,
	}
}

// RequestWithPattern creates a web.RequestMatcher with a path pattern and optional methods.
// If a context is available when performing the match, the context path is stripped first.
func RequestWithPattern(pattern string, methods ...string) web.RequestMatcher {
	pDelegate := matcher.WithPathPattern(pattern)
	pathMatcher := &requestMatcher{
		description:   fmt.Sprintf(descTmplPath, pDelegate.(fmt.Stringer).String()),
		matchableFunc: path,
		delegate:      pDelegate,
	}
	methodMatcher := RequestWithMethods(methods...)
	return wrapAsRequestMatcher(pathMatcher.And(methodMatcher))
}
// RequestWithURL is similar to RequestWithPattern, but takes a relative URL and converts it to a pattern
// by extracting its "path" part (removing #fragment, ?query and more; see PatternFromURL).
// The result matches requests, so it is typed as web.RequestMatcher (previously it was mislabeled as
// web.RouteMatcher; the returned value still satisfies that interface).
func RequestWithURL(url string, methods ...string) web.RequestMatcher {
	return RequestWithPattern(PatternFromURL(url), methods...)
}
// RequestWithPrefix creates a web.RequestMatcher matching requests whose path starts with prefix,
// optionally restricted to methods.
// If a context is available when performing the match, the context path is stripped first.
func RequestWithPrefix(prefix string, methods ...string) web.RequestMatcher {
	pDelegate := matcher.WithPrefix(prefix, true)
	pathMatcher := &requestMatcher{
		description:   fmt.Sprintf(descTmplPath, pDelegate.(fmt.Stringer).String()),
		matchableFunc: path,
		delegate:      pDelegate,
	}
	methodMatcher := RequestWithMethods(methods...)
	return wrapAsRequestMatcher(pathMatcher.And(methodMatcher))
}

// RequestWithRegex creates a web.RequestMatcher matching requests whose path matches regex,
// optionally restricted to methods.
// If a context is available when performing the match, the context path is stripped first.
func RequestWithRegex(regex string, methods ...string) web.RequestMatcher {
	pDelegate := matcher.WithRegex(regex)
	pathMatcher := &requestMatcher{
		description:   fmt.Sprintf(descTmplPath, pDelegate.(fmt.Stringer).String()),
		matchableFunc: path,
		delegate:      pDelegate,
	}
	methodMatcher := RequestWithMethods(methods...)
	return wrapAsRequestMatcher(pathMatcher.And(methodMatcher))
}
// RequestWithHeader matches requests whose header `name` equals value, or starts with value when prefix is true.
func RequestWithHeader(name string, value string, prefix bool) web.RequestMatcher {
	matchable := func(_ context.Context, r *http.Request) (interface{}, error) {
		return r.Header.Get(name), nil
	}
	var delegate matcher.Matcher
	if prefix {
		delegate = matcher.WithPrefix(value, true)
	} else {
		delegate = matcher.WithString(value, true)
	}
	return &requestMatcher{
		description:   fmt.Sprintf("matches header %s:%s", name, value),
		matchableFunc: matchable,
		delegate:      delegate,
	}
}

// RequestHasHeader matches requests that have a non-empty value for header `name`.
func RequestHasHeader(name string) web.RequestMatcher {
	matchable := func(_ context.Context, r *http.Request) (interface{}, error) {
		return r.Header.Get(name), nil
	}
	return &requestMatcher{
		description:   fmt.Sprintf("matches have header %s", name),
		matchableFunc: matchable,
		delegate:      matcher.AnyNonEmptyString(),
	}
}

// RequestHasPostForm matches http.Request that has a non-empty value for the given parameter in the
// POST/PUT/PATCH request body. URL query parameters are ignored (see http.Request.PostFormValue).
func RequestHasPostForm(param string) web.RequestMatcher {
	return &requestMatcher{
		description:   fmt.Sprintf(`matches have form parameter [%s] in body`, param),
		matchableFunc: postForm(param),
		delegate:      matcher.AnyNonEmptyString(),
	}
}

// RequestHasForm matches http.Request that has a non-empty value for the given parameter in the query or post body.
func RequestHasForm(param string) web.RequestMatcher {
	return &requestMatcher{
		description:   fmt.Sprintf(`matches have form parameter [%s]`, param),
		matchableFunc: form(param),
		delegate:      matcher.AnyNonEmptyString(),
	}
}

// RequestWithForm matches http.Request that has the exact param=value pair in the query or post body.
func RequestWithForm(param, value string) web.RequestMatcher {
	return &requestMatcher{
		description:   fmt.Sprintf(`matches have form data %s=%s`, param, value),
		matchableFunc: query(param),
		delegate:      matcher.WithString(value, true),
	}
}
// CustomMatcher creates a web.RequestMatcher from a value extractor and a delegate matcher.
// When matchable is nil, the request itself is passed to delegate.
func CustomMatcher(description string, matchable matchableFunc, delegate matcher.Matcher) web.RequestMatcher {
	m := &requestMatcher{}
	m.description = description
	m.matchableFunc = matchable
	m.delegate = delegate
	return m
}
/**************************
helpers
***************************/
// interfaceToRequest converts i to *http.Request. A *http.Request is returned as-is; an http.Request
// value is copied and a pointer to the copy is returned. Any other type yields an error.
func interfaceToRequest(i interface{}) (*http.Request, error) {
	switch v := i.(type) {
	case *http.Request:
		return v, nil
	case http.Request:
		return &v, nil
	default:
		return nil, fmt.Errorf("web.RequestMatcher doesn't support %T", i)
	}
}
// wrapAsRequestMatcher adapts a generic matcher into a web.RequestMatcher that matches the request itself.
// The description is taken from m's String method, if it has one.
func wrapAsRequestMatcher(m matcher.Matcher) web.RequestMatcher {
	wrapped := &requestMatcher{delegate: m}
	if s, ok := m.(fmt.Stringer); ok {
		wrapped.description = s.String()
	}
	return wrapped
}
// host extracts r.Host (which may include a port) as the matchable value.
func host(_ context.Context, r *http.Request) (interface{}, error) {
	h := r.Host
	return h, nil
}

// method extracts the request's HTTP method as the matchable value.
func method(_ context.Context, r *http.Request) (interface{}, error) {
	m := r.Method
	return m, nil
}
// path extracts the request URL path with the context path (as reported by web.ContextPath(c))
// removed from its beginning.
func path(c context.Context, r *http.Request) (interface{}, error) {
	return strings.TrimPrefix(r.URL.Path, web.ContextPath(c)), nil
}
// query returns an extractor for the first value of form parameter `name`, taken from the
// parsed URL query and request body. ParseForm errors are returned unchanged.
func query(name string) matchableFunc {
	return func(_ context.Context, r *http.Request) (interface{}, error) {
		if err := r.ParseForm(); err != nil {
			return nil, err
		}
		return r.Form.Get(name), nil
	}
}

// form returns an extractor for r.FormValue(name), covering both query and body parameters.
// ParseForm errors are wrapped in a descriptive error.
func form(name string) matchableFunc {
	return func(_ context.Context, r *http.Request) (interface{}, error) {
		if err := r.ParseForm(); err != nil {
			return nil, fmt.Errorf("can't find post form data from request: %v", err)
		}
		return r.FormValue(name), nil
	}
}

// postForm returns an extractor for r.PostFormValue(name), covering only body parameters.
// ParseForm errors are wrapped in a descriptive error.
func postForm(name string) matchableFunc {
	return func(_ context.Context, r *http.Request) (interface{}, error) {
		if err := r.ParseForm(); err != nil {
			return nil, fmt.Errorf("can't find post form data from request: %v", err)
		}
		return r.PostFormValue(name), nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package matcher
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/matcher"
"github.com/cisco-open/go-lanai/pkg/web"
"net/url"
pathutils "path"
"strings"
)
// routeMatcher implements web.RouteMatcher.
// It extracts a value from the route with matchableFunc and matches it with delegate.
// When matchableFunc is nil, the *web.Route itself is passed to delegate.
type routeMatcher struct {
	description   string
	matchableFunc func(*web.Route) interface{}
	delegate      matcher.Matcher
}
// RouteMatches extracts the matchable value from r (or uses r itself when no extractor is set)
// and delegates the match.
func (m *routeMatcher) RouteMatches(c context.Context, r *web.Route) (bool, error) {
	var target interface{} = r
	if m.matchableFunc != nil {
		target = m.matchableFunc(r)
	}
	return m.delegate.MatchesWithContext(c, target)
}

// Matches is MatchesWithContext with context.TODO().
func (m *routeMatcher) Matches(i interface{}) (bool, error) {
	return m.MatchesWithContext(context.TODO(), i)
}

// MatchesWithContext accepts web.Route or *web.Route; other types yield an error.
func (m *routeMatcher) MatchesWithContext(c context.Context, i interface{}) (bool, error) {
	route, err := interfaceToRoute(i)
	if err != nil {
		return false, err
	}
	return m.RouteMatches(c, route)
}

// Or combines this matcher with others using logical OR.
func (m *routeMatcher) Or(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.Or(m, matchers...)
}

// And combines this matcher with others using logical AND.
func (m *routeMatcher) And(matchers ...matcher.Matcher) matcher.ChainableMatcher {
	return matcher.And(m, matchers...)
}

// String returns the description, falling back to the delegate's String and then a generic name.
func (m *routeMatcher) String() string {
	if len(m.description) != 0 {
		return m.description
	}
	if s, ok := m.delegate.(fmt.Stringer); ok {
		return s.String()
	}
	return "RouteMatcher"
}
/**************************
Constructors
***************************/
// AnyRoute matches every route.
func AnyRoute() web.RouteMatcher {
	any := matcher.Any()
	return wrapAsRouteMatcher(any)
}

// RouteWithMethods matches routes whose method is any of methods; with no methods, any route matches.
func RouteWithMethods(methods ...string) web.RouteMatcher {
	var delegate matcher.ChainableMatcher
	switch len(methods) {
	case 0:
		delegate = matcher.Any()
	default:
		delegate = matcher.WithString(methods[0], true)
		for _, other := range methods[1:] {
			delegate = delegate.Or(matcher.WithString(other, true))
		}
	}
	return &routeMatcher{
		description:   fmt.Sprintf("method %s", delegate.(fmt.Stringer).String()),
		matchableFunc: routeMethod,
		delegate:      delegate,
	}
}
// RouteWithPattern checks web.Route's path with prefix
// The prefix syntax is:
//
// prefix:
// { term }
// term:
// '*' matches any sequence of non-path-separators
// '**' matches any sequence of characters, including
// path separators.
// '?' matches any single non-path-separator character
// '[' [ '^' ] { character-range } ']'
// character class (must be non-empty)
// '{' { term } [ ',' { term } ... ] '}'
// c matches character c (c != '*', '?', '\\', '[')
// '\\' c matches character c
//
// character-range:
// c matches character c (c != '\\', '-', ']')
// '\\' c matches character c
// lo '-' hi matches character c for lo <= c <= hi
func RouteWithPattern(pattern string, methods...string) web.RouteMatcher {
pDelegate := matcher.WithPathPattern(pattern)
pMatcher := &routeMatcher{
description: fmt.Sprintf(descTmplPath, pDelegate.(fmt.Stringer).String()),
matchableFunc: routeAbsPath,
delegate: pDelegate,
}
mMatcher := RouteWithMethods(methods...)
return wrapAsRouteMatcher(pMatcher.And(mMatcher))
}
// RouteWithURL is similar with RouteWithPattern, but instead it takes a relative URL path and convert it to pattern
// by extracting "path" part (remove #fragment, ?query and more)
func RouteWithURL(url string, methods...string) web.RouteMatcher {
return RouteWithPattern(PatternFromURL(url), methods...)
}
func RouteWithPrefix(prefix string, methods...string) web.RouteMatcher {
pDelegate := matcher.WithPrefix(prefix, true)
pMatcher := &routeMatcher{
description: fmt.Sprintf(descTmplPath, pDelegate.(fmt.Stringer).String()),
matchableFunc: routeAbsPath,
delegate: pDelegate,
}
mMatcher := RouteWithMethods(methods...)
return wrapAsRouteMatcher(pMatcher.And(mMatcher))
}
// RouteWithRegex returns a web.RouteMatcher that checks a web.Route's absolute path (group joined with path)
// against the given regular expression (via matcher.WithRegex), AND-combined with RouteWithMethods(methods...).
func RouteWithRegex(regex string, methods ...string) web.RouteMatcher {
	regexDelegate := matcher.WithRegex(regex)
	desc := fmt.Sprintf(descTmplPath, regexDelegate.(fmt.Stringer).String())
	regexMatcher := &routeMatcher{
		description:   desc,
		matchableFunc: routeAbsPath,
		delegate:      regexDelegate,
	}
	methodMatcher := RouteWithMethods(methods...)
	return wrapAsRouteMatcher(regexMatcher.And(methodMatcher))
}
// RouteWithGroup returns a web.RouteMatcher that compares a web.Route's Group with the given group string
// (via matcher.WithString).
func RouteWithGroup(group string) web.RouteMatcher {
	groupDelegate := matcher.WithString(group, false)
	desc := fmt.Sprintf("group %s", groupDelegate.(fmt.Stringer).String())
	return &routeMatcher{
		description:   desc,
		matchableFunc: routeGroup,
		delegate:      groupDelegate,
	}
}
// PatternFromURL converts a relative URL into a path pattern by keeping only its path part,
// i.e. dropping any "?query" and "#fragment" portions.
//
// The URL is parsed with url.Parse and its decoded path is returned. If parsing fails (e.g. an invalid
// percent-escape), the raw input is truncated at the first '?' or '#' instead. This makes the fallback strip the
// same portions a successful parse would.
func PatternFromURL(relativeUrl string) string {
	u, e := url.Parse(relativeUrl)
	if e != nil {
		if idx := strings.IndexAny(relativeUrl, "?#"); idx >= 0 {
			return relativeUrl[:idx]
		}
		return relativeUrl
	}
	return u.Path
}
/**************************
helpers
***************************/
func interfaceToRoute(i interface{}) (*web.Route, error) {
switch v := i.(type) {
case web.Route:
return &v, nil
case *web.Route:
return v, nil
default:
return nil, fmt.Errorf("RouteMatcher doesn't support %T", i)
}
}
// routeGroup extracts the route's Group as the value to be matched.
func routeGroup(r *web.Route) interface{} {
	group := r.Group
	return group
}
// routeMethod extracts the route's Method as the value to be matched.
func routeMethod(r *web.Route) interface{} {
	method := r.Method
	return method
}
// routeAbsPath joins the route's Group and Path and returns the cleaned result as an absolute path.
// A leading "/" is prepended when the joined path is relative.
func routeAbsPath(r *web.Route) interface{} {
	joined := pathutils.Join(r.Group, r.Path)
	if pathutils.IsAbs(joined) {
		return pathutils.Clean(joined)
	}
	return pathutils.Clean("/" + joined)
}
func wrapAsRouteMatcher(m matcher.Matcher) web.RouteMatcher {
var desc string
if stringer, ok := m.(fmt.Stringer); ok {
desc = stringer.String()
}
return &routeMatcher{
description: desc,
delegate: m,
}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import "net/http"
// ConditionalMiddleware is an optional interface that a Middleware can implement to control when the middleware is
// applied, e.g. a middleware that should only be applied when the request's header contains "Authorization".
// middleware.MappingBuilder uses Condition() as the mapping's condition unless one is set explicitly via WithCondition.
type ConditionalMiddleware interface {
	// Condition returns the RequestMatcher that decides whether the middleware applies to a request.
	Condition() RequestMatcher
}
// Middleware provides a http.HandlerFunc to be used by MiddlewareMapping and middleware.MappingBuilder.
// A Middleware may additionally implement ConditionalMiddleware to restrict when it is applied.
type Middleware interface {
	// HandlerFunc returns the handler implementing the middleware.
	HandlerFunc() http.HandlerFunc
}
// middlewareMapping is the MiddlewareMapping implementation backed by a http.HandlerFunc.
// Instances are created via NewMiddlewareMapping.
type middlewareMapping struct {
	name        string           // name of the mapping
	order       int              // ordering value; NOTE(review): confirm whether lower values run first
	matcher     RouteMatcher     // selects the routes this middleware is attached to
	condition   RequestMatcher   // optional per-request condition; may be nil
	handlerFunc http.HandlerFunc // the middleware handler
}
// NewMiddlewareMapping creates a MiddlewareMapping backed by a http.HandlerFunc.
// Prefer middleware.MappingBuilder over calling this function directly, e.g.
// <code>
// middleware.NewBuilder("my-auth").Order(-10).Use(func...).Build()
// </code>
func NewMiddlewareMapping(name string, order int, matcher RouteMatcher, cond RequestMatcher, handlerFunc http.HandlerFunc) MiddlewareMapping {
	mapping := middlewareMapping{
		name:        name,
		order:       order,
		matcher:     matcher,
		condition:   cond,
		handlerFunc: handlerFunc,
	}
	return &mapping
}
// Name returns the name of the mapping.
func (m middlewareMapping) Name() string {
	return m.name
}

// Matcher returns the RouteMatcher selecting the routes this middleware is attached to.
func (m middlewareMapping) Matcher() RouteMatcher {
	return m.matcher
}

// Order returns the ordering value of the mapping.
func (m middlewareMapping) Order() int {
	return m.order
}

// Condition returns the optional per-request condition, which may be nil.
func (m middlewareMapping) Condition() RequestMatcher {
	return m.condition
}

// HandlerFunc returns the middleware handler.
func (m middlewareMapping) HandlerFunc() http.HandlerFunc {
	return m.handlerFunc
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/gin-gonic/gin"
"net/http"
)
// MappingBuilder builds web.MiddlewareMapping.
// Either MappingBuilder.Use or MappingBuilder.With should be set in order to successfully build a mapping.
// If MappingBuilder.ApplyTo is not set, the middleware intercepts all routes.
// Supported handler functions are gin.HandlerFunc and http.HandlerFunc.
// Example:
// <code>
// middleware.NewBuilder("my-auth").Order(-10).ApplyTo(matcher.RouteWithPattern("/api/v1/**")).Use(func...).Build()
// </code>
type MappingBuilder struct {
	name       string
	middleware web.Middleware   // set via With; provides the default handler (and condition if it is a web.ConditionalMiddleware)
	matcher    web.RouteMatcher // set via ApplyTo
	order      int
	// overrides: when set, these take precedence over what middleware provides (see Build)
	condition   web.RequestMatcher // set via WithCondition
	handlerFunc interface{}        // set via Use; gin.HandlerFunc or http.HandlerFunc
}
// NewBuilder creates a MappingBuilder. The first of the optional names becomes the mapping name; "unknown" is used
// when no name is given. The order starts at 0.
func NewBuilder(names ...string) *MappingBuilder {
	b := &MappingBuilder{name: "unknown"}
	if len(names) != 0 {
		b.name = names[0]
	}
	return b
}
/*****************************
Public
******************************/
// Name overrides the mapping name given to NewBuilder.
func (b *MappingBuilder) Name(name string) *MappingBuilder {
	b.name = name
	return b
}

// Order sets the mapping's ordering value.
func (b *MappingBuilder) Order(order int) *MappingBuilder {
	b.order = order
	return b
}

// With sets the web.Middleware providing the default handler and, if it implements web.ConditionalMiddleware,
// the default condition.
func (b *MappingBuilder) With(middleware web.Middleware) *MappingBuilder {
	b.middleware = middleware
	return b
}

// ApplyTo sets the web.RouteMatcher selecting which routes the middleware is attached to.
func (b *MappingBuilder) ApplyTo(matcher web.RouteMatcher) *MappingBuilder {
	b.matcher = matcher
	return b
}
// Use sets the middleware handler, overriding the handler provided via With. Supported types:
//   - gin.HandlerFunc, or a plain func(*gin.Context)
//   - http.HandlerFunc, or a plain func(http.ResponseWriter, *http.Request)
//
// Plain function literals are converted to the corresponding named handler type. A type switch on the named types
// alone would not match an unnamed func value, so Use(func(c *gin.Context){...}) would otherwise panic.
// Any other type causes a panic.
func (b *MappingBuilder) Use(handlerFunc interface{}) *MappingBuilder {
	switch fn := handlerFunc.(type) {
	case gin.HandlerFunc, http.HandlerFunc:
		b.handlerFunc = fn
	case func(*gin.Context):
		b.handlerFunc = gin.HandlerFunc(fn)
	case func(http.ResponseWriter, *http.Request):
		b.handlerFunc = http.HandlerFunc(fn)
	default:
		panic(fmt.Errorf("unsupported HandlerFunc type: %T", handlerFunc))
	}
	return b
}
// WithCondition sets a per-request condition. It takes precedence over any condition provided by a
// web.ConditionalMiddleware set via With.
func (b *MappingBuilder) WithCondition(condition web.RequestMatcher) *MappingBuilder {
	b.condition = condition
	return b
}
// Build creates the web.MiddlewareMapping.
// The handler and condition default to those of the web.Middleware set via With. The condition is taken only if
// the middleware implements web.ConditionalMiddleware. Values set via Use and WithCondition take precedence.
// A gin.HandlerFunc produces a mapping via web.NewMiddlewareGinMapping; a http.HandlerFunc produces one via
// web.NewMiddlewareMapping. Build panics if no handler of a supported type is available.
func (b *MappingBuilder) Build() web.MiddlewareMapping {
	var condition web.RequestMatcher
	var handlerFunc interface{}
	// defaults from the middleware set via With
	if b.middleware != nil {
		handlerFunc = b.middleware.HandlerFunc()
		if conditional, ok := b.middleware.(web.ConditionalMiddleware); ok {
			condition = conditional.Condition()
		}
	}
	// explicit overrides set via Use / WithCondition
	if b.handlerFunc != nil {
		handlerFunc = b.handlerFunc
	}
	if b.condition != nil {
		condition = b.condition
	}
	switch v := handlerFunc.(type) {
	case gin.HandlerFunc:
		return web.NewMiddlewareGinMapping(b.name, b.order, b.matcher, condition, v)
	case http.HandlerFunc:
		return web.NewMiddlewareMapping(b.name, b.order, b.matcher, condition, v)
	default:
		// %T reports the offending type; a func value printed with %v is only an address
		panic(fmt.Errorf("unable to build '%s' middleware mapping: unsupported HandlerFunc type %T. please use With(...) or Use(...)", b.name, handlerFunc))
	}
}
/*****************************
Getters
******************************/
// GetRouteMatcher returns the route matcher set via ApplyTo (nil if not set).
func (b *MappingBuilder) GetRouteMatcher() web.RouteMatcher {
	return b.matcher
}

// GetCondition returns the condition set via WithCondition (nil if not set).
// Conditions provided by the middleware set via With are not reflected here.
func (b *MappingBuilder) GetCondition() web.RequestMatcher {
	return b.condition
}

// GetName returns the current mapping name.
func (b *MappingBuilder) GetName() string {
	return b.name
}

// GetOrder returns the current ordering value.
func (b *MappingBuilder) GetOrder() int {
	return b.order
}
/*****************************
Helpers
******************************/
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"net/http"
)
// mvcMapping is the default MvcMapping implementation, created via NewMvcMapping.
// Each field is exposed by the MvcMapping accessor of the same name.
type mvcMapping struct {
	name               string
	group              string
	path               string
	method             string
	condition          RequestMatcher
	decodeRequestFunc  DecodeRequestFunc
	encodeResponseFunc EncodeResponseFunc
	encodeErrorFunc    EncodeErrorFunc
	endpoint           MvcHandlerFunc // exposed via HandlerFunc()
}
// NewMvcMapping creates a MvcMapping from its individual parts.
// Prefer rest.MappingBuilder or template.MappingBuilder over calling this function directly, e.g.
// <code>
// rest.Put("/path/to/api").EndpointFunc(func...).Build()
// template.Post("/path/to/page").HandlerFunc(func...).Build()
// </code>
func NewMvcMapping(name, group, path, method string, condition RequestMatcher,
	mvcHandlerFunc MvcHandlerFunc,
	decodeRequestFunc DecodeRequestFunc,
	encodeResponseFunc EncodeResponseFunc,
	errorEncoder EncodeErrorFunc) MvcMapping {
	mapping := new(mvcMapping)
	mapping.name, mapping.group, mapping.path, mapping.method = name, group, path, method
	mapping.condition = condition
	mapping.endpoint = mvcHandlerFunc
	mapping.decodeRequestFunc = decodeRequestFunc
	mapping.encodeResponseFunc = encodeResponseFunc
	mapping.encodeErrorFunc = errorEncoder
	return mapping
}
/*****************************
MvcMapping Interface
******************************/
// Name returns the mapping name.
func (m *mvcMapping) Name() string {
	return m.name
}

// Group returns the route group of the mapping.
func (m *mvcMapping) Group() string {
	return m.group
}

// Path returns the route path of the mapping.
func (m *mvcMapping) Path() string {
	return m.path
}

// Method returns the HTTP method of the mapping.
func (m *mvcMapping) Method() string {
	return m.method
}

// Condition returns the optional per-request condition.
func (m *mvcMapping) Condition() RequestMatcher {
	return m.condition
}

// DecodeRequestFunc returns the request decoder.
func (m *mvcMapping) DecodeRequestFunc() DecodeRequestFunc {
	return m.decodeRequestFunc
}

// EncodeResponseFunc returns the response encoder.
func (m *mvcMapping) EncodeResponseFunc() EncodeResponseFunc {
	return m.encodeResponseFunc
}

// EncodeErrorFunc returns the error encoder.
func (m *mvcMapping) EncodeErrorFunc() EncodeErrorFunc {
	return m.encodeErrorFunc
}

// HandlerFunc returns the endpoint handler.
func (m *mvcMapping) HandlerFunc() MvcHandlerFunc {
	return m.endpoint
}
/*********************
Response
**********************/
// Response is a generic response holder implementing StatusCoder, Headerer and BodyContainer.
// If B itself implements one of those interfaces, B's implementation takes precedence over SC, H or B respectively.
type Response struct {
	SC int         // status code
	H  http.Header // response headers
	B  interface{} // response body
}
// StatusCode implements StatusCoder.
// If the body B implements StatusCoder, its status code takes precedence over SC.
func (r Response) StatusCode() int {
	coder, ok := r.B.(StatusCoder)
	if !ok {
		return r.SC
	}
	return coder.StatusCode()
}

// Headers implements Headerer.
// If the body B implements Headerer, its headers take precedence over H.
func (r Response) Headers() http.Header {
	headerer, ok := r.B.(Headerer)
	if !ok {
		return r.H
	}
	return headerer.Headers()
}

// Body implements BodyContainer.
// If the body B itself implements BodyContainer, its Body() is returned instead of B.
func (r Response) Body() interface{} {
	container, ok := r.B.(BodyContainer)
	if !ok {
		return r.B
	}
	return container.Body()
}
/**********************************
LazyHeaderWriter
***********************************/
// LazyHeaderWriter defers sending the status code and headers until the first body write.
// WriteHeader only records the status code, so it can be called multiple times to update it.
// Headers set via Header() are kept in a local map.
// On Write (or an explicit WriteHeaderNow), the local headers are copied onto the wrapped http.ResponseWriter.
// The recorded status code is sent to the wrapped writer only once.
// This lets response encoders and error handling settle on different headers and status codes before anything is sent.
type LazyHeaderWriter struct {
	http.ResponseWriter
	sc      int
	header  http.Header
	written bool // whether the status code has already been sent to the wrapped writer
}

// Header returns the locally buffered header map (not the wrapped writer's map).
func (w *LazyHeaderWriter) Header() http.Header {
	return w.header
}

// WriteHeader records the status code to be sent later; it does not send anything.
func (w *LazyHeaderWriter) WriteHeader(code int) {
	w.sc = code
}

// Write sends the status code and headers (see WriteHeaderNow) and then writes p to the wrapped writer.
func (w *LazyHeaderWriter) Write(p []byte) (int, error) {
	w.WriteHeaderNow()
	return w.ResponseWriter.Write(p)
}

// WriteHeaderNow copies the buffered headers onto the wrapped writer and sends the recorded status code.
// The status code is sent only on the first call. Without this guard, every Write would call the wrapped writer's
// WriteHeader again (a superfluous WriteHeader call).
func (w *LazyHeaderWriter) WriteHeaderNow() {
	// Merge header overwrite
	dst := w.ResponseWriter.Header()
	for k, v := range w.header {
		dst[k] = v
	}
	if w.written {
		return
	}
	w.written = true
	w.ResponseWriter.WriteHeader(w.sc)
}

// NewLazyHeaderWriter wraps w. The initial status code is http.StatusOK and the initial headers are a copy of w's
// current headers.
func NewLazyHeaderWriter(w http.ResponseWriter) *LazyHeaderWriter {
	// make a copy of current header from wrapped writer; value slices are copied too, so that
	// appending to the lazy header never shares a backing array with the wrapped writer's header
	header := make(http.Header)
	for k, v := range w.Header() {
		header[k] = append(make([]string, 0, len(v)), v...)
	}
	return &LazyHeaderWriter{ResponseWriter: w, sc: http.StatusOK, header: header}
}
/*********************
MVC Handler
**********************/
// mvcHandler holds the functions of an MVC request pipeline built by makeMvcHttpHandlerFunc:
// decode request -> handle -> encode response. errEncoder handles an error from any of these stages.
type mvcHandler struct {
	reqDecoder  DecodeRequestFunc
	respEncoder EncodeResponseFunc
	errEncoder  EncodeErrorFunc
	handlerFunc MvcHandlerFunc
}
// makeMvcHttpHandlerFunc adapts an MvcMapping into a http.HandlerFunc.
// The pipeline functions are taken from the mapping, then each option may modify them.
// For every request, the returned handler decodes the request, invokes the MvcHandlerFunc, and encodes the result.
// The first error from any of these steps is passed to the error encoder and processing stops.
func makeMvcHttpHandlerFunc(m MvcMapping, opts ...func(h *mvcHandler)) http.HandlerFunc {
	h := &mvcHandler{
		reqDecoder:  m.DecodeRequestFunc(),
		respEncoder: m.EncodeResponseFunc(),
		errEncoder:  m.EncodeErrorFunc(),
		handlerFunc: m.HandlerFunc(),
	}
	for i := range opts {
		opts[i](h)
	}
	return func(rw http.ResponseWriter, req *http.Request) {
		ctx := req.Context()
		decoded, err := h.reqDecoder(ctx, req)
		if err != nil {
			h.errEncoder(ctx, err, rw)
			return
		}
		result, err := h.handlerFunc(ctx, decoded)
		if err != nil {
			h.errEncoder(ctx, err, rw)
			return
		}
		if err = h.respEncoder(ctx, rw, result); err != nil {
			h.errEncoder(ctx, err, rw)
		}
	}
}
package web
import (
"context"
"errors"
"io"
"net/http"
"reflect"
)
/**********************************
Request Decoder
***********************************/
// GinBindingRequestDecoder is a DecodeRequestFunc using gin.Context's binding capabilities.
// instantiateFunc creates the object to bind into; its bindable pointer is resolved by resolveBindable.
// Header, URI and query parameters are always bound. ShouldBind then selects the remaining binding from the request.
// An io.EOF from ShouldBind is tolerated when the request has no body (ContentLength <= 0).
// The bound value is returned together with the result of validateBinding.
func GinBindingRequestDecoder(instantiateFunc func() interface{}) DecodeRequestFunc {
	return func(c context.Context, r *http.Request) (interface{}, error) {
		gc := GinContext(c)
		if gc == nil {
			return nil, NewHttpError(http.StatusInternalServerError, errors.New("context issue"))
		}
		target, result := resolveBindable(instantiateFunc())
		if err := bind(target, gc.ShouldBindHeader, gc.ShouldBindUri, gc.ShouldBindQuery); err != nil {
			return nil, translateBindingError(err)
		}
		if err := gc.ShouldBind(target); err != nil {
			emptyBody := errors.Is(err, io.EOF) && r.ContentLength <= 0
			if !emptyBody {
				return nil, translateBindingError(err)
			}
		}
		return result.Interface(), validateBinding(c, target)
	}
}
// bindingFunc binds request data into obj, returning an error on failure.
type bindingFunc func(interface{}) error

// bind applies each binding to obj in order and stops at the first error, which it returns.
// It returns nil when every binding succeeds or when no bindings are given.
func bind(obj interface{}, bindings ...bindingFunc) error {
	for i := range bindings {
		if err := bindings[i](obj); err != nil {
			return err
		}
	}
	return nil
}
// translateBindingError converts a raw binding failure into the error produced by NewBindingError.
func translateBindingError(err error) error {
	translated := NewBindingError(err)
	return translated
}
// validateBinding runs bindingValidator.StructCtx on obj when obj is a struct or a pointer to a struct.
// It returns nil when no validator is configured or obj is not a struct.
func validateBinding(ctx context.Context, obj interface{}) error {
	if bindingValidator == nil {
		return nil
	}
	// reflect.Indirect dereferences one pointer level, exactly like an explicit Kind()==Ptr check + Elem()
	if reflect.Indirect(reflect.ValueOf(obj)).Kind() != reflect.Struct {
		return nil
	}
	return bindingValidator.StructCtx(ctx, obj)
}
// resolveBindable using reflection to resolve bindable pointer of actual value.
func resolveBindable(i interface{}) (bindablePtr interface{}, actual reflect.Value) {
switch v := i.(type) {
case reflect.Value:
actual = v
default:
actual = reflect.ValueOf(i)
}
switch actual.Kind() {
case reflect.Ptr:
bindablePtr = actual.Interface()
default:
if actual.CanAddr() {
bindablePtr = actual.Addr().Interface()
}
}
return
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"context"
"encoding"
"encoding/json"
"errors"
"fmt"
"net/http"
)
/**********************************
Various Response Encoders
***********************************/
// EncodeOptions is a functional option that configures an EncodeOption.
type EncodeOptions func(opt *EncodeOption)

// EncodeOption holds the settings used by the response encoders to write a response.
type EncodeOption struct {
	ContentType string                                            // value used for the "Content-Type" response header
	Writer      http.ResponseWriter                               // destination writer
	Response    interface{}                                       // the response to encode
	WriteFunc   func(rw http.ResponseWriter, v interface{}) error // serializes the (body of the) response into rw
}
// JsonResponseEncoder returns an EncodeResponseFunc writing the response as JSON
// with Content-Type "application/json; charset=utf-8".
func JsonResponseEncoder() EncodeResponseFunc {
	var encoder EncodeResponseFunc = jsonEncodeResponseFunc
	return encoder
}

// TextResponseEncoder returns an EncodeResponseFunc writing the response as plain text
// with Content-Type "text/plain; charset=utf-8".
func TextResponseEncoder() EncodeResponseFunc {
	var encoder EncodeResponseFunc = textEncodeResponseFunc
	return encoder
}

// BytesResponseEncoder returns an EncodeResponseFunc writing the response as raw bytes
// with Content-Type "application/octet-stream".
func BytesResponseEncoder() EncodeResponseFunc {
	var encoder EncodeResponseFunc = bytesEncodeResponseFunc
	return encoder
}
// CustomResponseEncoder returns an EncodeResponseFunc configured by opts.
// On every call, the writer and response are applied first, followed by opts in order, so opts may override them.
func CustomResponseEncoder(opts ...EncodeOptions) EncodeResponseFunc {
	return func(c context.Context, rw http.ResponseWriter, response interface{}) error {
		all := make([]EncodeOptions, 0, len(opts)+1)
		all = append(all, func(opt *EncodeOption) {
			opt.Writer = rw
			opt.Response = response
		})
		all = append(all, opts...)
		return encodeResponse(c, all...)
	}
}
/**********************************
JSON Response Encoder
***********************************/
// JsonWriteFunc serializes v as JSON into rw using encoding/json's Encoder (which appends a trailing newline).
func JsonWriteFunc(rw http.ResponseWriter, v interface{}) error {
	enc := json.NewEncoder(rw)
	return enc.Encode(v)
}
// jsonEncodeResponseFunc encodes response as JSON via encodeResponse.
func jsonEncodeResponseFunc(c context.Context, rw http.ResponseWriter, response interface{}) error {
	configure := func(opt *EncodeOption) {
		opt.ContentType = "application/json; charset=utf-8"
		opt.WriteFunc = JsonWriteFunc
		opt.Writer = rw
		opt.Response = response
	}
	return encodeResponse(c, configure)
}
/**********************************
Text Response Encoder
***********************************/
// TextWriteFunc writes v as text into rw. It supports []byte, string, fmt.Stringer and encoding.TextMarshaler,
// checked in that order. Any other type results in a 500 HttpError.
func TextWriteFunc(rw http.ResponseWriter, v interface{}) error {
	var data []byte
	switch val := v.(type) {
	case []byte:
		data = val
	case string:
		data = []byte(val)
	case fmt.Stringer:
		data = []byte(val.String())
	case encoding.TextMarshaler:
		text, err := val.MarshalText()
		if err != nil {
			return err
		}
		data = text
	default:
		return NewHttpError(http.StatusInternalServerError, errors.New("invalid response type"))
	}
	_, err := rw.Write(data)
	return err
}
// textEncodeResponseFunc encodes response as plain text via encodeResponse.
func textEncodeResponseFunc(c context.Context, rw http.ResponseWriter, response interface{}) error {
	configure := func(opt *EncodeOption) {
		opt.ContentType = "text/plain; charset=utf-8"
		opt.WriteFunc = TextWriteFunc
		opt.Writer = rw
		opt.Response = response
	}
	return encodeResponse(c, configure)
}
/**********************************
Bytes Response Encoder
***********************************/
// BytesWriteFunc writes v as raw bytes into rw. It supports []byte, string and encoding.BinaryMarshaler,
// checked in that order. Any other type results in a 500 HttpError.
func BytesWriteFunc(rw http.ResponseWriter, v interface{}) error {
	var data []byte
	switch val := v.(type) {
	case []byte:
		data = val
	case string:
		data = []byte(val)
	case encoding.BinaryMarshaler:
		bin, err := val.MarshalBinary()
		if err != nil {
			return err
		}
		data = bin
	default:
		return NewHttpError(http.StatusInternalServerError, errors.New("invalid response type"))
	}
	_, err := rw.Write(data)
	return err
}
// bytesEncodeResponseFunc encodes response as "application/octet-stream" via encodeResponse.
func bytesEncodeResponseFunc(c context.Context, rw http.ResponseWriter, response interface{}) error {
	configure := func(opt *EncodeOption) {
		opt.ContentType = "application/octet-stream"
		opt.WriteFunc = BytesWriteFunc
		opt.Writer = rw
		opt.Response = response
	}
	return encodeResponse(c, configure)
}
/**********************************
Response Encoding Helpers
***********************************/
// encodeResponse works with endpoints generated with MakeEndpoint.
// We could export this function if needed. But for now, it remains hidden.
//
// Steps, after applying opts in order:
//  1. If the response is a Headerer, the writer is wrapped in a LazyHeaderWriter and the response headers are applied.
//  2. Content-Type is set from opt.ContentType. It is skipped when empty, so an option set without ContentType
//     (e.g. from CustomResponseEncoder) does not emit an empty Content-Type header.
//  3. If the response is a StatusCoder, its status code is written.
//  4. If the response is a BodyContainer, its Body() becomes the value to write.
//  5. A nil body still triggers an empty Write; otherwise opt.WriteFunc serializes it. A missing WriteFunc yields
//     a 500 HttpError instead of a nil-func panic.
func encodeResponse(_ context.Context, opts ...EncodeOptions) error {
	opt := EncodeOption{}
	for _, f := range opts {
		f(&opt)
	}
	// overwrite headers
	if headerer, ok := opt.Response.(Headerer); ok {
		opt.Writer = NewLazyHeaderWriter(opt.Writer)
		overwriteHeaders(opt.Writer, headerer)
	}
	// additional headers
	if opt.ContentType != "" {
		opt.Writer.Header().Set("Content-Type", opt.ContentType)
	}
	// write header and status code
	if coder, ok := opt.Response.(StatusCoder); ok {
		opt.Writer.WriteHeader(coder.StatusCode())
	}
	if entity, ok := opt.Response.(BodyContainer); ok {
		opt.Response = entity.Body()
	}
	// a nil (interface) body writes no content, but still performs a Write so status code and headers get sent
	switch resp := opt.Response.(type) {
	case nil:
		_, e := opt.Writer.Write([]byte{})
		return e
	default:
		if opt.WriteFunc == nil {
			return NewHttpError(http.StatusInternalServerError, errors.New("no WriteFunc configured for response encoding"))
		}
		return opt.WriteFunc(opt.Writer, resp)
	}
}
// overwriteHeaders copies h.Headers() onto w's header map, replacing any existing values for each key present in h.
// Keys with no values are left untouched.
func overwriteHeaders(w http.ResponseWriter, h Headerer) {
	dst := w.Header()
	for key, values := range h.Headers() {
		if len(values) == 0 {
			continue
		}
		dst.Set(key, values[0])
		for _, extra := range values[1:] {
			dst.Add(key, extra)
		}
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/pkg/errors"
)
/***********************
Server
************************/
const (
	// ServerPropertiesPrefix is the configuration prefix that BindServerProperties binds ServerProperties from.
	ServerPropertiesPrefix = "server"
)
// ServerProperties holds the web server configuration bound from the "server" prefix.
type ServerProperties struct {
	// Port is the listening port; values <= 0 let the system choose a free port (see Registrar.Run).
	Port int `json:"port"`
	// ContextPath is the path prefix of the router group all mappings are registered on.
	ContextPath string            `json:"context-path"`
	Logging     LoggingProperties `json:"logging"`
}
// LoggingProperties holds the "server.logging" settings of ServerProperties.
type LoggingProperties struct {
	Enabled      bool             `json:"enabled"`
	DefaultLevel log.LoggingLevel `json:"default-level"`
	// Levels overrides the logging level on particular sets of paths; see LoggingLevelProperties.
	Levels map[string]LoggingLevelProperties `json:"levels"`
}
// LoggingLevelProperties is used to override the logging level on a particular set of paths.
// LoggingLevelProperties.Pattern supports wildcards and should not include the "context-path".
// LoggingLevelProperties.Method is a space-separated list of methods. If left blank or if it contains "*", it
// matches all methods.
type LoggingLevelProperties struct {
	Method  string           `json:"method"`
	Pattern string           `json:"pattern"`
	Level   log.LoggingLevel `json:"level"`
}
// NewServerProperties creates a ServerProperties with default values:
// port -1 (system-chosen), context path "/", logging enabled at debug level with no per-path overrides.
func NewServerProperties() *ServerProperties {
	props := new(ServerProperties)
	props.Port = -1
	props.ContextPath = "/"
	props.Logging = LoggingProperties{
		Enabled:      true,
		DefaultLevel: log.LevelDebug,
		Levels:       make(map[string]LoggingLevelProperties),
	}
	return props
}
// BindServerProperties creates a ServerProperties with defaults and binds configuration under
// ServerPropertiesPrefix on top of them. It panics if binding fails.
func BindServerProperties(ctx *bootstrap.ApplicationContext) ServerProperties {
	props := NewServerProperties()
	err := ctx.Config().Bind(props, ServerPropertiesPrefix)
	if err != nil {
		panic(errors.Wrap(err, "failed to bind ServerProperties"))
	}
	return *props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"context"
"github.com/gin-gonic/gin"
)
// RecoveryCustomizer implements Customizer by installing gin.Recovery() as a global middleware.
type RecoveryCustomizer struct {
}
// NewRecoveryCustomizer creates a RecoveryCustomizer.
func NewRecoveryCustomizer() *RecoveryCustomizer {
	return new(RecoveryCustomizer)
}
// Customize implements Customizer by adding gin.Recovery() to the Registrar's global middlewares.
func (RecoveryCustomizer) Customize(_ context.Context, r *Registrar) error {
	recovery := gin.Recovery()
	return r.AddGlobalMiddlewares(recovery)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/pkg/utils/reflectutils"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"go.uber.org/fx"
"html/template"
"io/fs"
"net"
"net/http"
pathutils "path"
"reflect"
"sort"
"strings"
"time"
)
//goland:noinspection GoUnusedConst
const (
	// DefaultGroup is the group assigned to routed mappings that don't specify one.
	DefaultGroup = "/"
)
// Registrar collects web components (controllers, mappings, middlewares, customizers, error translators, etc.)
// and installs them onto the Engine in Initialize; Run additionally starts the http.Server.
type Registrar struct {
	engine           *Engine
	router           gin.IRouter  // group rooted at the configured context path
	server           *http.Server // created by Run; nil until then
	port             int          // actual listening port, set by Run on success
	properties       ServerProperties
	validator        *Validate
	requestRewriter  RequestRewriter
	middlewares      []MiddlewareMapping // middlewares gin-gonic middleware providers
	routedMappings   routedMappings      // routedMappings MvcMappings + SimpleMappings
	staticMappings   []StaticMapping     // staticMappings all static mappings
	customizers      []Customizer        // kept sorted with order.SortStable on each registration
	errMappings      []ErrorTranslateMapping
	errTranslators   []ErrorTranslator
	embedFs          []fs.FS // fs.FS instances passed to Register
	initialized      bool    // set once Initialize succeeds; later registrations are rejected
	warnDuplicateMWs bool    // configured via WarnDuplicateMiddlewares
	warnExclusion    utils.StringSet
}
// NewRegistrar creates a Registrar for engine g. Its router is a group rooted at the cleaned, absolute context path
// from properties. Duplicate-middleware warnings are enabled by default.
func NewRegistrar(g *Engine, properties ServerProperties) *Registrar {
	contextPath := pathutils.Clean("/" + properties.ContextPath)
	return &Registrar{
		engine:           g,
		router:           g.Group(contextPath),
		properties:       properties,
		validator:        bindingValidator,
		requestRewriter:  newGinRequestRewriter(g.Engine),
		routedMappings:   routedMappings{},
		warnDuplicateMWs: true,
		warnExclusion:    utils.NewStringSet(),
	}
}
// Initialize should be called during application startup. It is the last chance to change configurations, load
// templates, etc. It registers the mandatory customizers, applies all customizers, disables gin's automatic binding
// validation, loads HTML templates, and installs all registered mappings onto the engine.
// Calling it again after a successful initialization returns an error.
// Note: This function is exported for test utilities. Normal applications should use Registrar.Run which invokes this function internally.
func (r *Registrar) Initialize(ctx context.Context) (err error) {
	if r.initialized {
		return errors.New("attempting to initialize web engine multiple times")
	}
	// first, we add some mandatory customizers and middleware
	r.MustRegister(NewPriorityGinContextCustomizer(&r.properties))
	// apply customizers before installing mappings
	if err = r.applyCustomizers(ctx); err != nil {
		return err
	}
	// we disable auto-validation. We will invoke our own validation manually.
	// Also we need to make the validator available globally for any request decoder to access.
	// The alternative approach is to put the validator into each gin.Context
	binding.Validator = nil
	// load templates
	r.loadHtmlTemplates(ctx)
	// before starting to register mappings, we want global MW to take effect on our main group,
	// so the group is re-created here (presumably because gin groups capture the engine's middlewares at creation)
	var contextPath = pathutils.Clean("/" + r.properties.ContextPath)
	r.router = r.engine.Group(contextPath)
	// register routedMappings to gin engine
	if err = r.installMappings(ctx); err != nil {
		return err
	}
	r.initialized = true
	return nil
}
// Cleanup kicks off post-initialization cleanups by running every registered PostInitCustomizer.
// It returns the first error encountered.
// Note: This function is exported for test utilities. Normal applications should use Registrar.Run which invokes this function internally.
func (r *Registrar) Cleanup(ctx context.Context) (err error) {
	return r.applyPostInitCustomizers(ctx)
}
// AddGlobalMiddlewares adds gin middlewares to the engine itself, so they apply to all mappings.
// It always returns nil.
func (r *Registrar) AddGlobalMiddlewares(handlerFuncs ...gin.HandlerFunc) error {
	r.engine.Use(handlerFuncs...)
	return nil
}

// AddEngineOptions customizes the Engine by invoking each option on it in order. It always returns nil.
func (r *Registrar) AddEngineOptions(opts ...EngineOptions) error {
	for i := range opts {
		opts[i](r.engine)
	}
	return nil
}

// WarnDuplicateMiddlewares toggles duplicate-middleware warnings and adds paths to exclude from them.
// Excluded paths accumulate across calls.
func (r *Registrar) WarnDuplicateMiddlewares(ifWarn bool, excludedPath ...string) {
	r.warnDuplicateMWs = ifWarn
	r.warnExclusion.Add(excludedPath...)
}
// Run configure and start gin engine
func (r *Registrar) Run(ctx context.Context) (err error) {
if err = r.Initialize(ctx); err != nil {
return
}
defer func(ctx context.Context) {
_ = r.Cleanup(ctx)
}(ctx)
// we let system to choose port if not set
var addr = fmt.Sprintf(":%v", r.properties.Port)
if r.properties.Port <= 0 {
addr = ":0"
}
r.server = &http.Server{
Addr: addr,
Handler: r.engine,
ReadTimeout: 60 * time.Second,
WriteTimeout: 60 * time.Second,
MaxHeaderBytes: 1 << 20,
}
// start the server
tcpAddr, e := r.listenAndServe()
if e == nil {
r.port = tcpAddr.Port
}
return e
}
// Stop closes the http server immediately via http.Server.Close and logs the outcome.
// It returns an error if the server was never started by Run, or if Close fails.
func (r *Registrar) Stop(ctx context.Context) (err error) {
	if r.server == nil {
		return fmt.Errorf("attempt to stop server before initialization")
	}
	lg := logger.WithContext(ctx)
	if err = r.server.Close(); err != nil {
		lg.Warnf("error when stop http server: %v", err)
		return err
	}
	lg.Infof("http server stopped")
	return nil
}
// ServerPort returns the port of the started server, or 0 if the server has not been started successfully.
func (r *Registrar) ServerPort() int {
	return r.port
}

// Register is the entry point to register Controller, Mapping and other web related objects.
// Supported item types are:
//   - Controller
//   - MvcMapping
//   - StaticMapping
//   - MiddlewareMapping
//   - SimpleMapping
//   - ErrorTranslateMapping
//   - RequestPreProcessor
//   - Customizer
//   - ErrorTranslator
//   - fs.FS
//   - slices of the above
//   - struct that contains exported Controller fields
//
// Items are registered in order; registration stops at the first error, which is returned.
func (r *Registrar) Register(items ...interface{}) (err error) {
	for _, item := range items {
		if err = r.register(item); err != nil {
			return err
		}
	}
	return nil
}

// MustRegister is like Register but panics on error.
func (r *Registrar) MustRegister(items ...interface{}) {
	err := r.Register(items...)
	if err != nil {
		panic(err)
	}
}
// listenAndServe opens a TCP listener on r.server.Addr and serves r.server on it in a background goroutine.
// It returns the actually bound address, which matters when Addr uses port 0.
// The goroutine ends when Serve returns. http.ErrServerClosed is the expected result after Stop and is ignored.
// Any other Serve error is logged rather than silently discarded.
func (r *Registrar) listenAndServe() (*net.TCPAddr, error) {
	ln, err := net.Listen("tcp", r.server.Addr)
	if err != nil {
		return nil, err
	}
	go func() {
		if e := r.server.Serve(ln); e != nil && !errors.Is(e, http.ErrServerClosed) {
			logger.WithContext(context.Background()).Warnf("http server stopped unexpectedly: %v", e)
		}
	}()
	return ln.Addr().(*net.TCPAddr), nil
}
// register dispatches a single item to the matching registration function based on its type.
// Types matching none of the cases are handed to registerUnknownType (slices and structs with exported
// Controller fields). Registration is rejected once the engine is initialized.
//
// NOTE: case order matters — a value implementing several of these interfaces is registered only by the
// first matching case.
func (r *Registrar) register(i interface{}) (err error) {
	if r.initialized {
		return errors.New("attempting to register mappings/middlewares/pre-processors after web engine initialization")
	}
	switch v := i.(type) {
	case Controller:
		err = r.registerController(v)
	case MvcMapping:
		err = r.registerMvcMapping(v)
	case StaticMapping:
		err = r.registerStaticMapping(v)
	case MiddlewareMapping:
		err = r.registerMiddlewareMapping(v)
	case SimpleMapping:
		err = r.registerSimpleMapping(v)
	case ErrorTranslateMapping:
		err = r.registerErrorMapping(v)
	case RequestPreProcessor:
		err = r.registerRequestPreProcessor(v)
	case Customizer:
		err = r.registerWebCustomizer(v)
	case ErrorTranslator:
		err = r.registerErrorTranslator(v)
	case fs.FS:
		// file systems are only collected here
		r.embedFs = append(r.embedFs, v)
	default:
		err = r.registerUnknownType(i)
	}
	return
}
// registerUnknownType handles items that register does not recognize directly. Pointers are dereferenced first.
//   - slice: each element is registered via register (an empty slice is accepted).
//   - struct: each exported field holding a Controller is registered. A struct is accepted if it has at least one
//     such field or an exported fx.In field.
//
// Anything else returns an "unsupported type" error.
func (r *Registrar) registerUnknownType(i interface{}) (err error) {
	v := reflect.ValueOf(i)
	// dereference pointers; a nil pointer ends up as an invalid Value and is rejected below
	for v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	var valid bool
	switch v.Kind() {
	case reflect.Slice:
		for idx := 0; idx < v.Len(); idx++ {
			if e := r.register(v.Index(idx).Interface()); e != nil {
				return e
			}
		}
		// empty slice doesn't count as error
		valid = true
	case reflect.Struct:
		t := v.Type()
		for idx := 0; idx < v.NumField(); idx++ {
			// only exported fields are considered
			if !reflectutils.IsExportedField(t.Field(idx)) {
				continue
			}
			c := v.Field(idx).Interface()
			switch c.(type) {
			case fx.In:
				valid = true
			case Controller:
				valid = true
				if e := r.register(c); e != nil {
					return e
				}
			}
		}
	}
	if !valid {
		return fmt.Errorf("unsupported type [%T]", i)
	}
	return nil
}
// registerController registers every mapping returned by c.Mappings().
// It stops at the first failure and returns an error that names the controller type and wraps the cause,
// so errors.Is / errors.As still see the underlying error.
func (r *Registrar) registerController(c Controller) (err error) {
	for _, m := range c.Mappings() {
		if err = r.register(m); err != nil {
			return fmt.Errorf("invalid endpoint mapping in Controller [%T]: %w", c, err)
		}
	}
	return nil
}
// registerRoutedMapping files m under routedMappings[upper-cased method][group][normalized path].
// DefaultGroup is used when m has no group. Mappings sharing the same key are appended in registration order.
func (r *Registrar) registerRoutedMapping(m RoutedMapping) error {
	group := m.Group()
	if group == "" {
		group = DefaultGroup
	}
	method := strings.ToUpper(m.Method())
	path := NormalizedPath(m.Path())

	byPath := r.routedMappings.GetOrNew(method).GetOrNew(group)
	existing := byPath.GetOrNew(path)
	byPath[path] = append(existing, m)
	return nil
}
// registerSimpleMapping registers a SimpleMapping as a routed mapping.
func (r *Registrar) registerSimpleMapping(m SimpleMapping) error {
	return r.registerRoutedMapping(RoutedMapping(m))
}

// registerMvcMapping registers an MvcMapping as a routed mapping.
func (r *Registrar) registerMvcMapping(m MvcMapping) error {
	return r.registerRoutedMapping(RoutedMapping(m))
}

// registerStaticMapping collects a StaticMapping to be installed during Initialize. It always returns nil.
func (r *Registrar) registerStaticMapping(m StaticMapping) error {
	r.staticMappings = append(r.staticMappings, m)
	return nil
}

// registerMiddlewareMapping collects a MiddlewareMapping. It always returns nil.
func (r *Registrar) registerMiddlewareMapping(m MiddlewareMapping) error {
	r.middlewares = append(r.middlewares, m)
	return nil
}

// registerRequestPreProcessor hands the pre-processor to the engine. It always returns nil.
func (r *Registrar) registerRequestPreProcessor(p RequestPreProcessor) error {
	r.engine.addRequestPreProcessor(p)
	return nil
}
// registerWebCustomizer adds c to the customizer list and re-sorts the list with
// order.OrderedFirstCompare (stable), so applyCustomizers iterates them in that order.
// Registration is rejected once the registrar is initialized.
func (r *Registrar) registerWebCustomizer(c Customizer) error {
	if !r.initialized {
		r.customizers = append(r.customizers, c)
		order.SortStable(r.customizers, order.OrderedFirstCompare)
		return nil
	}
	return fmt.Errorf("cannot register web configurer after web engine have initialized")
}
// registerErrorMapping queues an ErrorTranslateMapping; it is rejected once the registrar is
// initialized. Queued mappings are sorted and matched per route during installMappings.
func (r *Registrar) registerErrorMapping(m ErrorTranslateMapping) error {
	if !r.initialized {
		r.errMappings = append(r.errMappings, m)
		return nil
	}
	return fmt.Errorf("cannot register error mappings after web engine have initialized")
}

// registerErrorTranslator queues a global ErrorTranslator (applied to every route); it is
// rejected once the registrar is initialized.
func (r *Registrar) registerErrorTranslator(t ErrorTranslator) error {
	if !r.initialized {
		r.errTranslators = append(r.errTranslators, t)
		return nil
	}
	return fmt.Errorf("cannot register error translator after web engine have initialized")
}
// applyCustomizers calls Customize on every registered customizer in list order and stops at
// the first error, returning it.
func (r *Registrar) applyCustomizers(ctx context.Context) error {
	for _, customizer := range r.customizers {
		if err := customizer.Customize(ctx, r); err != nil {
			return err
		}
	}
	return nil
}

// applyPostInitCustomizers calls PostInit on each registered customizer that also implements
// PostInitCustomizer, in list order, stopping at the first error. Ranging over a nil list is a
// no-op, so an empty registrar simply returns nil.
func (r *Registrar) applyPostInitCustomizers(ctx context.Context) error {
	for _, customizer := range r.customizers {
		postInit, ok := customizer.(PostInitCustomizer)
		if !ok {
			continue
		}
		if err := postInit.PostInit(ctx, r); err != nil {
			return err
		}
	}
	return nil
}
// installMappings installs every collected routed mapping and then every static mapping.
//
// Before any route is installed, error mappings are stably sorted with
// order.OrderedFirstCompare and the default error translator is appended to the global
// translators. For each (method, group, path) bucket, conditional mappings are stably moved
// ahead of unconditional ones before installRoutedMappings is called. The first failure is
// returned and installation stops.
func (r *Registrar) installMappings(ctx context.Context) error {
	order.SortStable(r.errMappings, order.OrderedFirstCompare)
	r.errTranslators = append(r.errTranslators, newDefaultErrorTranslator())

	for method, byGroup := range r.routedMappings {
		for group, byPath := range byGroup {
			for _, candidates := range byPath {
				// conditional mappings must precede the (single) unconditional one
				sort.SliceStable(candidates, func(a, b int) bool {
					return candidates[a].Condition() != nil && candidates[b].Condition() == nil
				})
				if err := r.installRoutedMappings(ctx, method, group, candidates); err != nil {
					return err
				}
			}
		}
	}

	for _, static := range r.staticMappings {
		if err := r.installStaticMapping(ctx, static); err != nil {
			return err
		}
	}
	return nil
}
// installStaticMapping serves the files under m.StaticRoot() at m.Path() in DefaultGroup.
//
// The served filesystem merges the OS directory m.StaticRoot() (wrapped with OrderedFS at
// order.Highest) with the same sub-directory of every registered embedded FS (wrapped with
// OrderedFS at its registration index). The handler chain is: filename rewrite, then every
// middleware matching GET/HEAD on the path, then the pre-compressed gzip asset handler.
//
// If middleware resolution fails, the error is returned and nothing is registered, so the
// static route is never exposed without its middlewares.
func (r *Registrar) installStaticMapping(ctx context.Context, m StaticMapping) error {
	embedded := make([]fs.FS, len(r.embedFs))
	for i, fsys := range r.embedFs {
		embedded[i] = OrderedFS(NewDirFS(m.StaticRoot(), fsys), i)
	}
	mFs := NewMergedFS(OrderedFS(NewOSDirFS(m.StaticRoot()), order.Highest), embedded...)
	mw := ginStaticAssetsHandler{
		rewriter: r.requestRewriter,
		fsys:     mFs,
		aliases:  m.Aliases(),
	}

	middlewares, err := r.findMiddlewares(ctx, DefaultGroup, m.Path(), http.MethodGet, http.MethodHead)
	if err != nil {
		return err
	}

	handlers := make(gin.HandlersChain, 0, len(middlewares)+2)
	handlers = append(handlers, mw.FilenameRewriteHandlerFunc())
	handlers = append(handlers, middlewares...)
	handlers = append(handlers, mw.PreCompressedGzipAsset())
	r.router.Group(DefaultGroup).
		Use(handlers...).
		StaticFS(m.Path(), http.FS(mFs))
	return nil
}
// installRoutedMappings registers every mapping that shares one (method, group, path) with the
// router. Callers pass conditional mappings first (see installMappings). At most one
// unconditional mapping is allowed, and every mapping must have the same path as mappings[0].
//
// Error translators and middlewares are resolved before anything is registered. If either
// lookup fails, the error is returned and the route stays unregistered, so it is never
// exposed without its middleware chain. An empty group is treated as DefaultGroup, and
// MethodAny is registered through the router's Any.
//
//nolint:contextcheck // context is only for logging purpose
func (r *Registrar) installRoutedMappings(ctx context.Context, method, group string, mappings []RoutedMapping) error {
	if len(mappings) == 0 {
		return nil
	}
	if group == "" {
		group = DefaultGroup
	}
	path := mappings[0].Path()

	// resolve error translators first
	errTranslators, e := r.findErrorTranslators(ctx, group, path, method)
	if e != nil {
		return fmt.Errorf("unable to resolve error translation for [%s %s]: %w", method, path, e)
	}
	order.SortStable(errTranslators, order.OrderedFirstCompare)

	// resolve gin.HandlerFunc to register
	handlerFuncs := make([]gin.HandlerFunc, len(mappings))
	unconditionalFound := false
	for i, m := range mappings {
		// validate method and path with best efforts
		switch {
		case path != m.Path():
			return fmt.Errorf("attempt to register multiple RoutedMappings with inconsist path parameters: "+
				"expected [%s (%s)%s] but got [%s (%s)%s]", method, group, path, m.Method(), m.Group(), m.Path())
		case m.Condition() == nil && unconditionalFound:
			return fmt.Errorf("attempt to register multiple unconditional RoutedMappings on same path and method: [%s %s]", m.Method(), m.Path())
		case m.Condition() == nil:
			unconditionalFound = true
		}

		// create handler func; SimpleGinMapping must be checked before SimpleMapping
		switch mapping := m.(type) {
		case MvcMapping:
			handlerFuncs[i] = r.makeHandlerFuncFromMvcMapping(mapping, errTranslators)
		case SimpleGinMapping:
			handlerFuncs[i] = r.makeGinConditionalHandlerFunc(mapping.GinHandlerFunc(), m.Condition())
		case SimpleMapping:
			f := NewHttpGinHandlerFunc(mapping.HandlerFunc())
			handlerFuncs[i] = r.makeGinConditionalHandlerFunc(f, m.Condition())
		default:
			// a nil handler here would only fail (panic) at request time
			return fmt.Errorf("unsupported RoutedMapping type [%T] for [%s %s]", m, method, path)
		}
	}

	// find middleware and register with router
	middlewares, err := r.findMiddlewares(ctx, group, path, method)
	if err != nil {
		return err
	}
	routes := r.router.Group(group).Use(middlewares...)
	if method == MethodAny {
		routes.Any(path, handlerFuncs...)
	} else {
		routes.Handle(method, path, handlerFuncs...)
	}
	return nil
}
// findMiddlewares returns the handlers of every middleware whose matcher accepts the route
// (group, relativePath, any of methods). Each handler is wrapped with its middleware's request
// Condition.
//
// r.middlewares is stably re-sorted by ascending Order() on each call, and matches are returned
// in that order. When warnDuplicateMWs is enabled, the matched middlewares are passed to
// logMatchedMiddlewares. A matcher error aborts the lookup and is returned with an empty chain.
//
//nolint:contextcheck // context is only for logging purpose
func (r *Registrar) findMiddlewares(ctx context.Context, group, relativePath string, methods ...string) (gin.HandlersChain, error) {
	sort.SliceStable(r.middlewares, func(a, b int) bool {
		return r.middlewares[a].Order() < r.middlewares[b].Order()
	})
	normalized := NormalizedPath(relativePath)
	chain := make([]gin.HandlerFunc, 0, len(r.middlewares))
	accepted := make([]MiddlewareMapping, 0, len(r.middlewares))
	for _, mw := range r.middlewares {
		ok, err := r.routeMatches(mw.Matcher(), group, normalized, methods...)
		if err != nil {
			return []gin.HandlerFunc{}, err
		}
		if !ok {
			continue
		}
		var handler gin.HandlerFunc
		if ginMW, isGin := mw.(MiddlewareGinMapping); isGin {
			handler = ginMW.GinHandlerFunc()
		} else {
			handler = NewHttpGinHandlerFunc(http.HandlerFunc(mw.HandlerFunc()))
		}
		chain = append(chain, r.makeGinConditionalHandlerFunc(handler, mw.Condition()))
		accepted = append(accepted, mw)
	}
	// warn duplicate MWs
	if r.warnDuplicateMWs {
		r.logMatchedMiddlewares(ctx, accepted, group, relativePath, methods)
	}
	return chain, nil
}
func (r *Registrar) findErrorTranslators(_ context.Context, group, relativePath string, methods ...string) ([]ErrorTranslator, error) {
var translators = make([]ErrorTranslator, len(r.errTranslators), len(r.errTranslators)+len(r.errMappings))
for i, t := range r.errTranslators {
translators[i] = t
}
var matched = make([]ErrorTranslateMapping, len(r.errMappings))
path := NormalizedPath(relativePath)
for i, m := range r.errMappings {
switch match, err := r.routeMatches(m.Matcher(), group, path, methods...); {
case err != nil:
return translators, err
case match:
matched[i] = m
translators = append(translators, newMappedErrorTranslator(m))
}
}
return translators, nil
}
// routeMatches reports whether matcher accepts the route for at least one of methods.
// It returns an error when no method is given. A nil matcher accepts everything. The first
// positive match or matcher error ends the scan and is returned as-is.
func (r *Registrar) routeMatches(matcher RouteMatcher, group, relativePath string, methods ...string) (bool, error) {
	if len(methods) == 0 {
		return false, fmt.Errorf("unable to register middleware: method is missing for %s", relativePath)
	}
	if matcher == nil {
		// no matcher, any value is a match
		return true, nil
	}
	for _, method := range methods {
		route := Route{Group: group, Path: relativePath, Method: method}
		if matched, err := matcher.Matches(route); matched || err != nil {
			return matched, err
		}
	}
	return false, nil
}
// loadHtmlTemplates parses every "**/*.tmpl" file from the OS directory "web/" (opened with
// DirFSAllowListDirectory) merged with the registered embedded filesystems, using the engine's
// FuncMap. On success the result becomes the engine's HTML template. On failure the error is
// logged at info level and SetHTMLTemplate is not called.
func (r *Registrar) loadHtmlTemplates(ctx context.Context) {
	sources := NewMergedFS(NewOSDirFS("web/", DirFSAllowListDirectory), r.embedFs...)
	parsed, err := template.New("html").
		Funcs(r.engine.FuncMap).
		ParseFS(sources, "**/*.tmpl")
	if err == nil {
		r.engine.SetHTMLTemplate(parsed)
		return
	}
	logger.WithContext(ctx).Infof("no templates loaded: %v", err)
}
/**************************
Helpers
***************************/
// makeHandlerFuncFromMvcMapping turns m into a gin handler. Its error encoder is decorated
// with errTranslators (via newErrorEncoder), and the handler is gated by m's Condition.
func (r *Registrar) makeHandlerFuncFromMvcMapping(m MvcMapping, errTranslators []ErrorTranslator) gin.HandlerFunc {
	withTranslators := func(h *mvcHandler) {
		h.errEncoder = newErrorEncoder(h.errEncoder, errTranslators...)
	}
	httpHandler := makeMvcHttpHandlerFunc(m, withTranslators)
	ginHandler := NewHttpGinHandlerFunc(httpHandler)
	return r.makeGinConditionalHandlerFunc(ginHandler, m.Condition())
}
// makeGinConditionalHandlerFunc gates handler behind rm. A nil rm returns handler unchanged.
// Otherwise the returned func calls handler only when rm matches the request. A matcher error
// is recorded on the gin context and the chain is aborted. A non-match (without error) is a
// no-op, leaving the request to the rest of the chain.
func (r *Registrar) makeGinConditionalHandlerFunc(handler gin.HandlerFunc, rm RequestMatcher) gin.HandlerFunc {
	if rm == nil {
		return handler
	}
	return func(c *gin.Context) {
		matches, e := rm.MatchesWithContext(c, c.Request)
		switch {
		case e != nil:
			_ = c.Error(e)
			c.Abort()
		case matches:
			handler(c)
		}
	}
}
// logMatchedMiddlewares logs important information about middlewares, mainly for debugging
// and early error detection. For now it only warns when several matched middlewares share the
// same non-empty Name(). Paths in r.warnExclusion are not reported. Duplicate entries are
// sorted so the warning is deterministic (map iteration order is random). A group of "/" is
// printed as an empty prefix.
func (r *Registrar) logMatchedMiddlewares(ctx context.Context, matched []MiddlewareMapping, group, path string, methods []string) {
	counts := map[string]int{}
	for _, mw := range matched {
		if name := mw.Name(); name != "" {
			counts[name]++
		}
	}
	var dups []string
	for name, n := range counts {
		if n > 1 {
			dups = append(dups, fmt.Sprintf(`"%s"x%d`, name, n))
		}
	}
	// the exclusion check is loop-invariant; it is only consulted once duplicates exist
	if len(dups) == 0 || r.warnExclusion.Has(path) {
		return
	}
	sort.Strings(dups)
	if group == "/" {
		group = ""
	}
	logger.WithContext(ctx).Warnf("multiple Middlewares with same name detected at %s%s %v: %v", group, path, methods, dups)
}
/***************************
helper types
***************************/
// routedMappings indexes routed mappings by HTTP method (then group, then path).
type routedMappings map[string]groupsMap

// GetOrNew returns the groupsMap stored under key, inserting an empty one first if absent.
func (m routedMappings) GetOrNew(key string) groupsMap {
	v, ok := m[key]
	if !ok {
		v = make(groupsMap)
		m[key] = v
	}
	return v
}

// groupsMap indexes routed mappings by group (then path).
type groupsMap map[string]pathsMap

// GetOrNew returns the pathsMap stored under key, inserting an empty one first if absent.
func (m groupsMap) GetOrNew(key string) pathsMap {
	v, ok := m[key]
	if !ok {
		v = make(pathsMap)
		m[key] = v
	}
	return v
}

// pathsMap holds the routed mappings registered for each path.
type pathsMap map[string][]RoutedMapping

// GetOrNew returns the slice stored under key, inserting a non-nil empty slice first if absent.
func (m pathsMap) GetOrNew(key string) []RoutedMapping {
	v, ok := m[key]
	if !ok {
		v = []RoutedMapping{}
		m[key] = v
	}
	return v
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package rest
import (
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/internal/mvc"
"net/http"
)
// EndpointFunc is a function with the following signature:
//   - one or two input parameters, the 1st being context.Context and the 2nd <request>
//   - at least two output parameters, the 2nd last being <response> and the last error
//
// where
//
//	<request>: a struct or a pointer to a struct whose fields are properly tagged
//	<response>: supported types are (more may be supported in the future):
//	  - a struct or a pointer to a struct whose fields are properly tagged
//	  - interface{}, if decoding is not supported (rest not used by any go client)
//	  - map[string]interface{}
//	  - string
//	  - []byte
//
// e.g.: func(context.Context, request *AnyStructWithTag) (response *AnyStructWithTag, error) {...}
type EndpointFunc any
// MappingBuilder builds a web.EndpointMapping (a web.MvcMapping under the hood).
// Unless overridden with DecodeRequestFunc, EncodeResponseFunc or EncodeErrorFunc, the built
// mapping decodes requests with mvc.GinBindingRequestDecoder, encodes responses with
// web.JsonResponseEncoder and encodes errors with web.JsonErrorEncoder.
// MappingBuilder.Path (or a non-root MappingBuilder.Group) and MappingBuilder.EndpointFunc are
// required to successfully build a mapping. MappingBuilder.Method defaults to web.MethodAny.
// See EndpointFunc for supported strongly typed function signatures.
// Example:
// <code>
// rest.Put("/path/to/api").EndpointFunc(func...).Build()
// </code>
type MappingBuilder struct {
	name               string // mapping name; derived from method, group and path when empty
	group              string
	path               string
	method             string // HTTP method; empty is treated as web.MethodAny
	condition          web.RequestMatcher
	endpointFunc       EndpointFunc
	decodeRequestFunc  web.DecodeRequestFunc  // optional override of the request decoder
	encodeResponseFunc web.EncodeResponseFunc // optional override of the response encoder
	encodeErrorFunc    web.EncodeErrorFunc    // optional override of the error encoder
}
// New creates a MappingBuilder with method web.MethodAny. The first element of names, if
// any, becomes the mapping name; additional elements are ignored.
func New(names ...string) *MappingBuilder {
	b := &MappingBuilder{method: web.MethodAny}
	if len(names) != 0 {
		b.name = names[0]
	}
	return b
}
// Convenient Constructors

// Any creates a builder for path that accepts any HTTP method.
func Any(path string) *MappingBuilder {
	return New().Method(web.MethodAny).Path(path)
}

// Get creates a builder for GET path.
func Get(path string) *MappingBuilder {
	return New().Method(http.MethodGet).Path(path)
}

// Post creates a builder for POST path.
func Post(path string) *MappingBuilder {
	return New().Method(http.MethodPost).Path(path)
}

// Put creates a builder for PUT path.
func Put(path string) *MappingBuilder {
	return New().Method(http.MethodPut).Path(path)
}

// Patch creates a builder for PATCH path.
func Patch(path string) *MappingBuilder {
	return New().Method(http.MethodPatch).Path(path)
}

// Delete creates a builder for DELETE path.
func Delete(path string) *MappingBuilder {
	return New().Method(http.MethodDelete).Path(path)
}

// Options creates a builder for OPTIONS path.
func Options(path string) *MappingBuilder {
	return New().Method(http.MethodOptions).Path(path)
}

// Head creates a builder for HEAD path.
func Head(path string) *MappingBuilder {
	return New().Method(http.MethodHead).Path(path)
}
/*****************************
Public
******************************/
// Name sets the mapping name. When left empty, buildMapping derives one from method, group
// and path.
func (b *MappingBuilder) Name(name string) *MappingBuilder {
	b.name = name
	return b
}

// Group sets the router group the mapping is registered under.
func (b *MappingBuilder) Group(group string) *MappingBuilder {
	b.group = group
	return b
}

// Path sets the path of the mapping, relative to its group.
func (b *MappingBuilder) Path(path string) *MappingBuilder {
	b.path = path
	return b
}

// Method sets the HTTP method. An empty value is treated as web.MethodAny when building.
func (b *MappingBuilder) Method(method string) *MappingBuilder {
	b.method = method
	return b
}

// Condition sets an optional request matcher that gates the mapping.
func (b *MappingBuilder) Condition(condition web.RequestMatcher) *MappingBuilder {
	b.condition = condition
	return b
}

// EndpointFunc sets the (required) endpoint function; see EndpointFunc for supported signatures.
func (b *MappingBuilder) EndpointFunc(endpointFunc EndpointFunc) *MappingBuilder {
	b.endpointFunc = endpointFunc
	return b
}
// Convenient setters

// Get sets both path and method (GET).
func (b *MappingBuilder) Get(path string) *MappingBuilder {
	return b.Method(http.MethodGet).Path(path)
}

// Post sets both path and method (POST).
func (b *MappingBuilder) Post(path string) *MappingBuilder {
	return b.Method(http.MethodPost).Path(path)
}

// Put sets both path and method (PUT).
func (b *MappingBuilder) Put(path string) *MappingBuilder {
	return b.Method(http.MethodPut).Path(path)
}

// Patch sets both path and method (PATCH).
func (b *MappingBuilder) Patch(path string) *MappingBuilder {
	return b.Method(http.MethodPatch).Path(path)
}

// Delete sets both path and method (DELETE).
func (b *MappingBuilder) Delete(path string) *MappingBuilder {
	return b.Method(http.MethodDelete).Path(path)
}

// Options sets both path and method (OPTIONS).
func (b *MappingBuilder) Options(path string) *MappingBuilder {
	return b.Method(http.MethodOptions).Path(path)
}

// Head sets both path and method (HEAD).
func (b *MappingBuilder) Head(path string) *MappingBuilder {
	return b.Method(http.MethodHead).Path(path)
}
// Overrides

// DecodeRequestFunc replaces the default request decoder (mvc.GinBindingRequestDecoder).
func (b *MappingBuilder) DecodeRequestFunc(f web.DecodeRequestFunc) *MappingBuilder {
	b.decodeRequestFunc = f
	return b
}

// EncodeResponseFunc replaces the default response encoder (web.JsonResponseEncoder).
func (b *MappingBuilder) EncodeResponseFunc(f web.EncodeResponseFunc) *MappingBuilder {
	b.encodeResponseFunc = f
	return b
}

// EncodeErrorFunc replaces the default error encoder (web.JsonErrorEncoder).
func (b *MappingBuilder) EncodeErrorFunc(f web.EncodeErrorFunc) *MappingBuilder {
	b.encodeErrorFunc = f
	return b
}
// Build validates the builder and returns the resulting mapping. It panics with the
// validation error if a required setting is missing.
func (b *MappingBuilder) Build() web.EndpointMapping {
	err := b.validate()
	if err == nil {
		return b.buildMapping()
	}
	panic(err)
}
/*****************************
Private
******************************/
// validate reports the first missing required setting: a path (required unless a non-root
// group is set), then an endpoint function.
func (b *MappingBuilder) validate() error {
	rootGroup := b.group == "" || b.group == "/"
	switch {
	case b.path == "" && rootGroup:
		return errors.New("empty path")
	case b.endpointFunc == nil:
		return errors.New("missing endpoint function")
	default:
		return nil
	}
}
// buildMapping assembles the web.MvcMapping from the builder's current settings without
// modifying the builder, so a builder can be adjusted and built again without carrying a
// stale default name.
//
// An empty method defaults to web.MethodAny. An empty name is derived as
// "<method> <group><path>". Unset codecs fall back to mvc.GinBindingRequestDecoder,
// web.JsonResponseEncoder and web.JsonErrorEncoder.
func (b *MappingBuilder) buildMapping() web.MvcMapping {
	method := b.method
	if method == "" {
		method = web.MethodAny
	}
	name := b.name
	if name == "" {
		name = fmt.Sprintf("%s %s%s", method, b.group, b.path)
	}

	metadata := mvc.NewFuncMetadata(b.endpointFunc, nil)
	decReq := b.decodeRequestFunc
	if decReq == nil {
		decReq = mvc.GinBindingRequestDecoder(metadata)
	}
	encResp := b.encodeResponseFunc
	if encResp == nil {
		encResp = web.JsonResponseEncoder()
	}
	encErr := b.encodeErrorFunc
	if encErr == nil {
		encErr = web.JsonErrorEncoder()
	}

	return web.NewMvcMapping(
		name, b.group, b.path, method, b.condition,
		metadata.HandlerFunc(), decReq, encResp, encErr,
	)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package template
import (
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/internal/mvc"
"net/http"
"reflect"
)
// supportedResponseTypes are the types one of a ModelViewHandlerFunc's outputs must be
// convertible to (checked by validateHandlerFunc): ModelView by value or by pointer.
var supportedResponseTypes = []reflect.Type{
	reflect.TypeOf(ModelView{}),
	reflect.TypeOf((*ModelView)(nil)),
}
// ModelViewHandlerFunc is a function with the following signature:
//   - two input parameters, the 1st being context.Context and the 2nd <request>
//   - two output parameters, the 1st being <response> and the 2nd error
//
// where
//
//	<request>: a struct or a pointer to a struct whose fields are properly tagged
//	<response>: a ModelView or a pointer to a ModelView
//
// e.g.: func(context.Context, request *AnyStructWithTag) (response *ModelView, error) {...}
type ModelViewHandlerFunc any
// MappingBuilder builds a web.TemplateMapping (a web.MvcMapping under the hood) that decodes
// requests with mvc.GinBindingRequestDecoder, renders responses with TemplateEncodeResponseFunc
// and renders errors with TemplateErrorEncoder.
// MappingBuilder.Path (or a non-root MappingBuilder.Group) and MappingBuilder.HandlerFunc are
// required to successfully build a mapping. MappingBuilder.Method defaults to web.MethodAny.
// See ModelViewHandlerFunc for supported strongly typed function signatures.
// Example:
// <code>
// template.Post("/path/to/page").HandlerFunc(func...).Build()
// </code>
type MappingBuilder struct {
	name        string // mapping name; derived when empty
	group       string
	path        string
	method      string // HTTP method; empty is treated as web.MethodAny
	condition   web.RequestMatcher
	handlerFunc ModelViewHandlerFunc
}
// New creates a template MappingBuilder with method web.MethodAny. The first element of
// names, if any, becomes the mapping name; additional elements are ignored.
func New(names ...string) *MappingBuilder {
	b := &MappingBuilder{method: web.MethodAny}
	if len(names) != 0 {
		b.name = names[0]
	}
	return b
}
// Convenient Constructors

// Any creates a template builder for path that accepts any HTTP method.
func Any(path string) *MappingBuilder {
	return New().Method(web.MethodAny).Path(path)
}

// Get creates a template builder for GET path.
func Get(path string) *MappingBuilder {
	return New().Method(http.MethodGet).Path(path)
}

// Post creates a template builder for POST path.
func Post(path string) *MappingBuilder {
	return New().Method(http.MethodPost).Path(path)
}
/*****************************
Public
******************************/
// Name sets the mapping name. When left empty, buildMapping derives one.
func (b *MappingBuilder) Name(name string) *MappingBuilder {
	b.name = name
	return b
}

// Group sets the router group the mapping is registered under.
func (b *MappingBuilder) Group(group string) *MappingBuilder {
	b.group = group
	return b
}

// Path sets the path of the mapping, relative to its group.
func (b *MappingBuilder) Path(path string) *MappingBuilder {
	b.path = path
	return b
}

// Method sets the HTTP method. An empty value is treated as web.MethodAny when building.
func (b *MappingBuilder) Method(method string) *MappingBuilder {
	b.method = method
	return b
}

// Condition sets an optional request matcher that gates the mapping.
func (b *MappingBuilder) Condition(condition web.RequestMatcher) *MappingBuilder {
	b.condition = condition
	return b
}

// HandlerFunc sets the (required) handler; see ModelViewHandlerFunc for supported signatures.
func (b *MappingBuilder) HandlerFunc(endpointFunc ModelViewHandlerFunc) *MappingBuilder {
	b.handlerFunc = endpointFunc
	return b
}
// Convenient setters

// Get sets both path and method (GET).
func (b *MappingBuilder) Get(path string) *MappingBuilder {
	return b.Method(http.MethodGet).Path(path)
}

// Post sets both path and method (POST).
func (b *MappingBuilder) Post(path string) *MappingBuilder {
	return b.Method(http.MethodPost).Path(path)
}
// Build validates the builder and returns the resulting template mapping. It panics with the
// validation error if a required setting is missing.
func (b *MappingBuilder) Build() web.TemplateMapping {
	err := b.validate()
	if err == nil {
		return b.buildMapping()
	}
	panic(err)
}
/*****************************
Private
******************************/
// validate reports the first missing required setting: a path (required unless a non-root
// group is set), then a handler func. Previously a missing handler silently overwrote the
// "empty path" error; returning early keeps the first failure, consistent with the rest
// package's builder.
func (b *MappingBuilder) validate() (err error) {
	if b.path == "" && (b.group == "" || b.group == "/") {
		return errors.New("empty path")
	}
	if b.handlerFunc == nil {
		return errors.New("handler func is required for template mapping")
	}
	return nil
}
// buildMapping assembles the web.MvcMapping from the builder's current settings without
// modifying the builder, so rebuilding after changes never reuses a stale default name.
//
// An empty method defaults to web.MethodAny. An empty name is derived as
// "<method> <group><path>", matching the rest package's builder. Requests are decoded with
// mvc.GinBindingRequestDecoder, and responses/errors are rendered with
// TemplateEncodeResponseFunc / TemplateErrorEncoder. validateHandlerFunc is handed to
// mvc.NewFuncMetadata as an additional validator.
func (b *MappingBuilder) buildMapping() web.MvcMapping {
	method := b.method
	if method == "" {
		method = web.MethodAny
	}
	name := b.name
	if name == "" {
		name = fmt.Sprintf("%s %s%s", method, b.group, b.path)
	}
	metadata := mvc.NewFuncMetadata(b.handlerFunc, validateHandlerFunc)
	decReq := mvc.GinBindingRequestDecoder(metadata)
	return web.NewMvcMapping(name, b.group, b.path, method, b.condition,
		metadata.HandlerFunc(), decReq, TemplateEncodeResponseFunc, TemplateErrorEncoder)
}
// validateHandlerFunc is an additional validator that ensures a ModelViewHandlerFunc returns a
// supported response type. At least one output must be convertible to a type in
// supportedResponseTypes (ModelView or *ModelView).
//
// It returns an error when f is missing/zero, when f is not a function (calling NumOut on a
// non-func type would panic), or when no output has a supported type.
func validateHandlerFunc(f *reflect.Value) error {
	if f == nil || !f.IsValid() || f.IsZero() {
		return errors.New("missing ModelViewHandlerFunc")
	}
	t := f.Type()
	if t.Kind() != reflect.Func {
		return errors.New("ModelViewHandlerFunc must be a function")
	}
	// check response type, scanning outputs from the last one backwards
	for i := t.NumOut() - 1; i >= 0; i-- {
		out := t.Out(i)
		for _, supported := range supportedResponseTypes {
			if out.ConvertibleTo(supported) {
				return nil
			}
		}
	}
	return errors.New("ModelViewHandlerFunc need return ModelView or *ModelView")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package template
import (
"context"
"github.com/cisco-open/go-lanai/pkg/web"
"net/http"
"reflect"
)
// RequestContext exposes request data to templates as a string-keyed map.
type RequestContext map[string]interface{}

// MakeRequestContext copies every exported, non-function field of *r into a RequestContext,
// keyed by field name, and adds "ContextPath" from web.ContextPath(ctx).
// r must be non-nil (reflect.Value.Type would panic on a nil request).
func MakeRequestContext(ctx context.Context, r *http.Request) RequestContext {
	req := reflect.ValueOf(r).Elem()
	reqType := req.Type()
	rc := make(RequestContext, reqType.NumField()+1)
	for i := 0; i < reqType.NumField(); i++ {
		field := reqType.Field(i)
		// TODO we should filter the values
		// we only put exported, non-function fields
		if !field.IsExported() || field.Type.Kind() == reflect.Func {
			continue
		}
		rc[field.Name] = req.Field(i).Interface()
	}
	rc["ContextPath"] = web.ContextPath(ctx)
	return rc
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package template
import (
	"context"
	"errors"
	"net/http"
	"net/url"
	"path"
	"reflect"

	"github.com/cisco-open/go-lanai/pkg/web"
	"github.com/gin-gonic/gin"
)
// Well-known model keys available to templates.
const (
	// ModelKeyError holds the error being rendered (set by TemplateErrorEncoder).
	ModelKeyError = "error"
	// ModelKeyErrorCode is not set in this file; presumably populated elsewhere — TODO confirm.
	ModelKeyErrorCode = "errorCode"
	// ModelKeyStatusCode holds the HTTP status code (set by TemplateErrorEncoder).
	ModelKeyStatusCode = "statusCode"
	// ModelKeyStatusText holds http.StatusText of the status code (set by TemplateErrorEncoder).
	ModelKeyStatusText = "statusText"
	// ModelKeyMessage holds err.Error() (set by TemplateErrorEncoder).
	ModelKeyMessage = "message"
	// ModelKeySession is not set in this file; presumably provided by another module.
	ModelKeySession = "session"
	// ModelKeyRequestContext holds the RequestContext (set by AddGlobalModelData).
	ModelKeyRequestContext = "rc"
	// ModelKeySecurity is not set in this file; presumably provided by another module.
	ModelKeySecurity = "security"
	// ModelKeyCsrf is not set in this file; presumably provided by another module.
	ModelKeyCsrf = "csrf"
)
// Markers used by RedirectView / isRedirect / redirect to encode a redirect as a ModelView.
var (
	// viewRedirect is the special view name identifying a redirect ModelView.
	viewRedirect = "redirect:"
	// modelKeyRedirectSC holds the redirect status code (int).
	modelKeyRedirectSC = "redirect.sc"
	// modelKeyRedirectLoc holds the redirect location (string).
	modelKeyRedirectLoc = "redirect.location"
	// modelKeyIgnoreCtxPath (bool) skips prefixing the location with the context path.
	modelKeyIgnoreCtxPath = "redirect.noCtxPath"
)
// Model is the data passed to a template; it has gin.H as its underlying type.
type Model gin.H

// ModelView pairs a template name with the model it is rendered with.
type ModelView struct {
	// View is the name of template file
	View string
	// Model is map[string]interface{}
	Model Model
}
// ModelValuer is the constraint for valuers accepted by RegisterGlobalModelValuer: a function
// (or a named type based on one) producing a model value from nothing, from the request's
// context.Context, or from the *http.Request.
type ModelValuer interface {
	~func() interface{} | ~func(ctx context.Context) interface{} | ~func(req *http.Request) interface{}
}

// StaticModelValuer wraps a fixed value as a ModelValuer that always returns it.
func StaticModelValuer(value interface{}) func() interface{} {
	return func() interface{} {
		return value
	}
}

// ContextModelValuer adapts a typed context-based function to a ModelValuer.
func ContextModelValuer[T any](fn func(ctx context.Context) T) func(context.Context) interface{} {
	return func(ctx context.Context) interface{} {
		return fn(ctx)
	}
}

// RequestModelValuer adapts a typed request-based function to a ModelValuer.
func RequestModelValuer[T any](fn func(req *http.Request) T) func(req *http.Request) interface{} {
	return func(req *http.Request) interface{} {
		return fn(req)
	}
}

// globalModelValuers maps model keys to registered valuers. Entries are stored as one of the
// unnamed signatures in modelValuerSignatures, which is what applyGlobalModelValuers switches on.
var globalModelValuers = map[string]interface{}{}

// modelValuerSignatures are the exact (unnamed) function types of the three ModelValuer shapes.
var modelValuerSignatures = []reflect.Type{
	reflect.TypeOf((func() interface{})(nil)),
	reflect.TypeOf((func(context.Context) interface{})(nil)),
	reflect.TypeOf((func(*http.Request) interface{})(nil)),
}

// RegisterGlobalModelValuer registers a ModelValuer under the given model key. The registered
// ModelValuer is applied before any ModelView is rendered.
// Use StaticModelValuer, ContextModelValuer or RequestModelValuer to wrap values/functions as
// ModelValuer. Valuers of a named function type (e.g. `type MyValuer func() interface{}`) are
// accepted too; they are converted to the underlying signature, since storing them as-is
// would make them invisible to the renderer's type switch.
func RegisterGlobalModelValuer[T ModelValuer](key string, valuer T) {
	globalModelValuers[key] = normalizeModelValuer(valuer)
}

// normalizeModelValuer converts valuer to the unnamed signature in modelValuerSignatures it
// is based on. Values that already have one of those exact types are returned unchanged.
func normalizeModelValuer(valuer interface{}) interface{} {
	rv := reflect.ValueOf(valuer)
	for _, sig := range modelValuerSignatures {
		switch {
		case rv.Type() == sig:
			return valuer
		case rv.Type().ConvertibleTo(sig):
			return rv.Convert(sig).Interface()
		}
	}
	return valuer
}
// RedirectView builds a ModelView that TemplateEncodeResponseFunc turns into a redirect to
// location. statusCode values outside 300-399 are replaced by http.StatusFound. When
// ignoreContextPath is false, a relative location is later prefixed with the context path.
func RedirectView(location string, statusCode int, ignoreContextPath bool) *ModelView {
	isRedirectCode := statusCode >= 300 && statusCode <= 399
	if !isRedirectCode {
		statusCode = http.StatusFound
	}
	model := Model{
		modelKeyRedirectSC:    statusCode,
		modelKeyRedirectLoc:   location,
		modelKeyIgnoreCtxPath: ignoreContextPath,
	}
	return &ModelView{View: viewRedirect, Model: model}
}
// isRedirect reports whether mv is a redirect ModelView: its View is viewRedirect and its
// model carries a string location and an int status code. mv must be non-nil.
func isRedirect(mv *ModelView) (ret bool) {
	if mv.View != viewRedirect {
		return false
	}
	_, hasLocation := mv.Model[modelKeyRedirectLoc].(string)
	_, hasStatus := mv.Model[modelKeyRedirectSC].(int)
	return hasLocation && hasStatus
}
// redirect returns the status code and final location of a redirect ModelView. Callers are
// expected to have checked isRedirect; a non-string location panics. An unparsable location
// is returned verbatim. Absolute URLs, and locations flagged to ignore the context path, are
// returned re-serialized. Otherwise the path is joined with web.ContextPath(ctx).
//
// NOTE(review): path.Join cleans its result, so a trailing slash in the location
// ("/login/" -> "<ctx>/login") is dropped — confirm this is intended.
func redirect(ctx context.Context, mv *ModelView) (int, string) {
	statusCode, _ := mv.Model[modelKeyRedirectSC].(int)
	rawLocation := mv.Model[modelKeyRedirectLoc].(string)
	parsed, err := url.Parse(rawLocation)
	if err != nil {
		return statusCode, rawLocation
	}
	skipCtxPath, _ := mv.Model[modelKeyIgnoreCtxPath].(bool)
	if skipCtxPath || parsed.IsAbs() {
		return statusCode, parsed.String()
	}
	parsed.Path = path.Join(web.ContextPath(ctx), parsed.Path)
	return statusCode, parsed.String()
}
/**********************************
Response Encoder
***********************************/
// TemplateEncodeResponseFunc renders a ModelView / *ModelView response with gin's HTML
// renderer, or issues a redirect when the ModelView was created by RedirectView.
//
// The status code comes from web.StatusCoder when the response implements it (default
// http.StatusOK). A web.BodyContainer response is unwrapped to its Body() first. A nil Model
// is replaced by an empty one, because global model data is written into it. It returns an
// error when no gin context is available, or when the (unwrapped) response is neither a
// ModelView nor a non-nil *ModelView.
//
//goland:noinspection GoNameStartsWithPackageName
func TemplateEncodeResponseFunc(ctx context.Context, _ http.ResponseWriter, response interface{}) error {
	gc := web.GinContext(ctx)
	if gc == nil {
		return errors.New("unable to use template: context is not available")
	}

	// get status code
	status := http.StatusOK
	if coder, ok := response.(web.StatusCoder); ok {
		status = coder.StatusCode()
	}
	if entity, ok := response.(web.BodyContainer); ok {
		response = entity.Body()
	}

	var mv *ModelView
	switch v := response.(type) {
	case *ModelView:
		mv = v
	case ModelView:
		mv = &v
	}
	// mv is also nil for a typed-nil *ModelView, which would otherwise panic in isRedirect
	if mv == nil {
		return errors.New("unable to use template: response is not *template.ModelView")
	}

	if isRedirect(mv) {
		gc.Redirect(redirect(ctx, mv))
		return nil
	}
	if mv.Model == nil {
		// AddGlobalModelData assigns into the model; assigning into a nil map panics
		mv.Model = Model{}
	}
	AddGlobalModelData(ctx, mv.Model, gc.Request)
	gc.HTML(status, mv.View, mv.Model)
	return nil
}
/*****************************
Template Error Encoder
******************************/
// TemplateErrorEncoder renders err with the web.ErrorTemplate template.
//
// Headers from a web.Headerer error are added to w first. The status code comes from
// web.StatusCoder (default http.StatusInternalServerError). Without a gin context, the error
// text is written as text/plain instead. Otherwise the model holds the error, its message,
// the status code and status text, plus global model data.
//
//nolint:errorlint
//goland:noinspection GoNameStartsWithPackageName
func TemplateErrorEncoder(c context.Context, err error, w http.ResponseWriter) {
	if h, ok := err.(web.Headerer); ok {
		for name, values := range h.Headers() {
			for _, value := range values {
				w.Header().Add(name, value)
			}
		}
	}

	statusCode := http.StatusInternalServerError
	if coder, ok := err.(web.StatusCoder); ok {
		statusCode = coder.StatusCode()
	}

	gc := web.GinContext(c)
	if gc != nil {
		model := Model{
			ModelKeyError:      err,
			ModelKeyMessage:    err.Error(),
			ModelKeyStatusCode: statusCode,
			ModelKeyStatusText: http.StatusText(statusCode),
		}
		AddGlobalModelData(c, model, gc.Request)
		gc.HTML(statusCode, web.ErrorTemplate, model)
		return
	}

	w.Header().Set(web.HeaderContentType, "text/plain; charset=utf-8")
	w.WriteHeader(statusCode)
	_, _ = w.Write([]byte(err.Error()))
}
// AddGlobalModelData stores the RequestContext of r under ModelKeyRequestContext, then
// applies every registered global model valuer. model must be non-nil.
func AddGlobalModelData(ctx context.Context, model Model, r *http.Request) {
	rc := MakeRequestContext(ctx, r)
	model[ModelKeyRequestContext] = rc
	applyGlobalModelValuers(ctx, r, model)
}
// applyGlobalModelValuers evaluates every registered valuer whose stored type is one of the
// three unnamed ModelValuer signatures. Non-nil results are written into model under their
// key; other entries are skipped.
func applyGlobalModelValuers(ctx context.Context, r *http.Request, model Model) {
	for key, raw := range globalModelValuers {
		var value interface{}
		if fn, ok := raw.(func() interface{}); ok {
			value = fn()
		} else if fn, ok := raw.(func(ctx context.Context) interface{}); ok {
			value = fn(ctx)
		} else if fn, ok := raw.(func(req *http.Request) interface{}); ok {
			value = fn(r)
		}
		if value == nil {
			continue
		}
		model[key] = value
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package webtracing
import (
"context"
"github.com/cisco-open/go-lanai/pkg/tracing"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/gin-gonic/gin"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"net/http"
)
// opName is the operation-name prefix for server spans; the request path is appended.
const opName = "http"

// Requests matched by excludeRequest are not traced (see GinTracing).
var (
	// healthMatcher matches requests whose path matches the pattern "**/health".
	healthMatcher = matcher.RequestWithPattern("**/health")
	// corsPreflightMatcher matches OPTIONS requests (CORS preflight).
	corsPreflightMatcher = matcher.RequestWithMethods(http.MethodOptions)
	// excludeRequest matches either of the above.
	excludeRequest = corsPreflightMatcher.Or(healthMatcher)
)
// tracingWebCustomizer is a web customizer that installs GinTracing as a global middleware.
type tracingWebCustomizer struct {
	tracer opentracing.Tracer // tracer used to create server spans
}
// newTracingWebCustomizer returns a customizer that traces requests with tracer.
func newTracingWebCustomizer(tracer opentracing.Tracer) *tracingWebCustomizer {
	c := new(tracingWebCustomizer)
	c.tracer = tracer
	return c
}
// Order returns order.Highest, because tracingWebCustomizer is meant to be applied before any
// other customizer.
func (c tracingWebCustomizer) Order() int {
	precedence := order.Highest
	return precedence
}
// Customize registers GinTracing (opName prefix, skipping excludeRequest) as a global
// middleware and returns any registration error.
func (c tracingWebCustomizer) Customize(_ context.Context, r *web.Registrar) error {
	//nolint:contextcheck // false positive
	err := r.AddGlobalMiddlewares(GinTracing(c.tracer, opName, excludeRequest))
	if err != nil {
		return err
	}
	return nil
}
// GinTracing returns a gin middleware that starts (or joins) a server span for each request
// not matched by excludes. The span context is attached to the request for the rest of the
// chain, finished with the response status code after gc.Next(), and the original request
// context is restored afterwards. excludes must be non-nil.
//
// NOTE(review): finishing is not deferred, so a panic further down the chain leaves the span
// unfinished — confirm whether that is acceptable.
func GinTracing(tracer opentracing.Tracer, opName string, excludes web.RequestMatcher) gin.HandlerFunc {
	return func(gc *gin.Context) {
		if skip, err := excludes.Matches(gc.Request); err == nil && skip {
			return
		}
		original := gc.Request.Context()
		spanCtx := contextWithRequest(original, tracer, gc.Request, opName)
		gc.Request = gc.Request.WithContext(spanCtx)

		gc.Next()

		finisher := tracing.WithTracer(tracer).
			WithOptions(tracing.SpanHttpStatusCode(gc.Writer.Status()))
		finisher.Finish(spanCtx)
		gc.Request = gc.Request.WithContext(original)
	}
}
/*********************
common funcs
*********************/
// opNameWithRequest returns "<opName> <request URL path>".
func opNameWithRequest(opName string, r *http.Request) string {
	urlPath := r.URL.Path
	return opName + " " + urlPath
}
// contextWithRequest starts an RPC-server span named "<opName> <path>" that records the
// request method and URL, and returns ctx carrying it. When an upstream span context can be
// extracted from the request headers, the new span is started with ext.RPCServerOption of it.
func contextWithRequest(ctx context.Context, tracer opentracing.Tracer, req *http.Request, opName string) context.Context {
	spanOp := tracing.WithTracer(tracer).
		WithOpName(opNameWithRequest(opName, req)).
		WithOptions(
			tracing.SpanKind(ext.SpanKindRPCServerEnum),
			tracing.SpanHttpMethod(req.Method),
			tracing.SpanHttpUrl(req.URL.String()),
		)
	carrier := opentracing.HTTPHeadersCarrier(req.Header)
	if parent, err := tracer.Extract(opentracing.HTTPHeaders, carrier); err == nil {
		spanOp = spanOp.WithStartOptions(ext.RPCServerOption(parent))
	}
	return spanOp.NewSpanOrDescendant(ctx)
}
package webtracing
import (
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/opentracing/opentracing-go"
"go.uber.org/fx"
)
// Module wires web request tracing into the application. It runs setup as a priority fx
// option with precedence web.MinWebPrecedence.
var Module = &bootstrap.Module{
	Name:       "web-tracing",
	Precedence: web.MinWebPrecedence,
	PriorityOptions: []fx.Option{
		fx.Invoke(setup),
	},
}
// initDI collects setup's dependencies. Both are optional; setup does nothing unless both
// are provided.
type initDI struct {
	fx.In
	Registrar *web.Registrar    `optional:"true"`
	Tracer    opentracing.Tracer `optional:"true"`
}
// setup registers the tracing web customizer when both a tracer and a registrar are available.
func setup(di initDI) {
	if di.Tracer == nil || di.Registrar == nil {
		return
	}
	di.Registrar.MustRegister(newTracingWebCustomizer(di.Tracer))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package web
import (
"fmt"
"github.com/gin-gonic/gin/binding"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
)
var (
	// bindingValidator wraps the *validator.Validate engine of gin's binding.Validator, so
	// validations registered on it apply to request binding.
	bindingValidator = newValidator(binding.Validator)
)
// Validator returns the package-level Validate wrapping gin's binding validator engine.
// Callers can register custom validators on it.
func Validator() *Validate {
	shared := bindingValidator
	return shared
}
// newValidator wraps ginValidator's engine, which must be a *validator.Validate (the type
// assertion panics otherwise).
func newValidator(ginValidator binding.StructValidator) *Validate {
	engine := ginValidator.Engine()
	return &Validate{Validate: engine.(*validator.Validate)}
}
// Validate is a thin wrapper around validator/v10. Its SetTagName override panics, so the tag
// name cannot be changed through this type's method set.
type Validate struct {
	*validator.Validate
}
// WithTagName is intended to return a Validate that uses a different struct tag name.
//
// NOTE(review): only the wrapper struct is copied; cp.Validate is the same
// *validator.Validate pointer as v.Validate. So cp.Validate.SetTagName(name) below also changes
// the tag name of v's underlying validator — e.g. the shared instance returned by Validator()
// when called on it. This contradicts the "shallow copy" intent. A plain struct copy of
// validator.Validate is presumably not a safe fix either, since its internal caches/pools may
// still reference the original instance — TODO verify against validator/v10 internals.
func (v *Validate) WithTagName(name string) *Validate {
	cp := Validate{
		Validate: v.Validate,
	}
	cp.Validate.SetTagName(name)
	return &cp
}
// SetTagName always panics. Modifying the tag name of the wrapped validator in place
// is not allowed; use WithTagName(string) instead.
func (v *Validate) SetTagName(name string) {
	err := fmt.Errorf("illegal attempt to modify tag of validator. Please use WithTagName(string)")
	panic(err)
}
// SetTranslations registers translations by calling regFn with the wrapped
// *validator.Validate and trans. It returns whatever error regFn returns.
func (v *Validate) SetTranslations(trans ut.Translator, regFn func(*validator.Validate, ut.Translator) error) error {
	inner := v.Validate
	if err := regFn(inner, trans); err != nil {
		return err
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/rest"
)
// Controller is a test controller exposing a single REST endpoint.
type Controller struct{}

// NewController returns a zero-value Controller.
func NewController() Controller {
	var c Controller
	return c
}

// Mappings returns one mapping: POST "/basic/:var" handled by StructPtr200.
func (c Controller) Mappings() []web.Mapping {
	basic := rest.Post("/basic/:var").
		EndpointFunc(StructPtr200).
		Build()
	return []web.Mapping{basic}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"net/url"
"strconv"
)
// JsonRequest is the test request model. Each field is bound from a different source:
// URI path variable, query string, header, and JSON body.
type JsonRequest struct {
	UriVar     string `uri:"var"`
	QueryVar   string `form:"q"`
	HeaderVar  string `header:"X-VAR"`
	JsonString string `json:"string"`
	JsonInt    int    `json:"int"`
}

// Response echoes every JsonRequest field under its own JSON name.
type Response struct {
	UriVar     string `json:"uri"`
	QueryVar   string `json:"q"`
	HeaderVar  string `json:"header"`
	JsonString string `json:"string"`
	JsonInt    int    `json:"int"`
}

// newResponse copies all fields of req into a newly allocated Response.
func newResponse(req *JsonRequest) *Response {
	resp := new(Response)
	resp.UriVar = req.UriVar
	resp.QueryVar = req.QueryVar
	resp.HeaderVar = req.HeaderVar
	resp.JsonString = req.JsonString
	resp.JsonInt = req.JsonInt
	return resp
}

// JsonResponse has the same fields as Response and defines no custom marshaling methods.
type JsonResponse Response

// newJsonResponse builds a JsonResponse echoing req.
func newJsonResponse(req *JsonRequest) *JsonResponse {
	r := newResponse(req)
	return (*JsonResponse)(r)
}

// TextResponse has the same fields as Response and marshals itself as text (see MarshalText).
type TextResponse Response

// newTextResponse builds a TextResponse echoing req.
func newTextResponse(req *JsonRequest) *TextResponse {
	r := newResponse(req)
	return (*TextResponse)(r)
}

// MarshalText encodes the response as URL-encoded form values
// ("key=value" pairs joined by "&", sorted by key).
func (r TextResponse) MarshalText() ([]byte, error) {
	fields := [...][2]string{
		{"uri", r.UriVar},
		{"q", r.QueryVar},
		{"header", r.HeaderVar},
		{"string", r.JsonString},
		{"int", strconv.Itoa(r.JsonInt)},
	}
	values := make(url.Values, len(fields))
	for _, f := range fields {
		values.Set(f[0], f[1])
	}
	encoded := values.Encode()
	return []byte(encoded), nil
}

// BytesResponse has the same fields as Response and marshals itself as binary (see MarshalBinary).
type BytesResponse Response

// newBytesResponse builds a BytesResponse echoing req.
func newBytesResponse(req *JsonRequest) *BytesResponse {
	r := newResponse(req)
	return (*BytesResponse)(r)
}

// MarshalBinary returns exactly the bytes produced by TextResponse.MarshalText.
func (r BytesResponse) MarshalBinary() ([]byte, error) {
	asText := TextResponse(r)
	return asText.MarshalText()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/gin-gonic/gin/binding"
"net/http"
)
/*********************
Supported
*********************/
// StructPtr200 echoes req as a *JsonResponse without an explicit status code.
func StructPtr200(_ context.Context, req *JsonRequest) (*JsonResponse, error) {
	resp := newJsonResponse(req)
	return resp, nil
}

// Struct200 echoes req as a JsonResponse value without an explicit status code.
func Struct200(_ context.Context, req JsonRequest) (JsonResponse, error) {
	resp := newJsonResponse(&req)
	return *resp, nil
}

// StructPtr201 echoes req as a *JsonResponse with status 201 Created.
func StructPtr201(_ context.Context, req *JsonRequest) (int, *JsonResponse, error) {
	resp := newJsonResponse(req)
	return http.StatusCreated, resp, nil
}

// Struct201 echoes req as a JsonResponse value with status 201 Created.
func Struct201(_ context.Context, req JsonRequest) (int, JsonResponse, error) {
	resp := newJsonResponse(&req)
	return http.StatusCreated, *resp, nil
}

// StructPtr201WithHeader is StructPtr201 plus an "X-VAR" response header echoing req.HeaderVar.
func StructPtr201WithHeader(_ context.Context, req *JsonRequest) (http.Header, int, *JsonResponse, error) {
	header := make(http.Header)
	header.Set("X-VAR", req.HeaderVar)
	resp := newJsonResponse(req)
	return header, http.StatusCreated, resp, nil
}

// Struct201WithHeader is Struct201 plus an "X-VAR" response header echoing req.HeaderVar.
func Struct201WithHeader(_ context.Context, req JsonRequest) (http.Header, int, JsonResponse, error) {
	header := make(http.Header)
	header.Set("X-VAR", req.HeaderVar)
	resp := newJsonResponse(&req)
	return header, http.StatusCreated, *resp, nil
}
func Raw(ctx context.Context, req *http.Request) (interface{}, error) {
gc := web.GinContext(ctx)
var jsonReq JsonRequest
_ = gc.BindUri(&jsonReq)
_ = binding.Query.Bind(req, &jsonReq)
_ = binding.Header.Bind(req, &jsonReq)
_ = binding.JSON.Bind(req, &jsonReq)
return newJsonResponse(&jsonReq), nil
}
// NoRequest takes no request input and returns an empty JsonResponse.
func NoRequest(_ context.Context) (*JsonResponse, error) {
	return new(JsonResponse), nil
}

// Text echoes req as a *TextResponse.
func Text(_ context.Context, req *JsonRequest) (*TextResponse, error) {
	resp := newTextResponse(req)
	return resp, nil
}

// TextString returns req's TextResponse encoding as a string.
func TextString(_ context.Context, req *JsonRequest) (string, error) {
	text, err := newTextResponse(req).MarshalText()
	return string(text), err
}

// TextBytes returns req's TextResponse encoding as raw bytes.
func TextBytes(_ context.Context, req *JsonRequest) ([]byte, error) {
	return newTextResponse(req).MarshalText()
}

// Bytes returns req's BytesResponse encoding as raw bytes.
func Bytes(_ context.Context, req *JsonRequest) ([]byte, error) {
	return newBytesResponse(req).MarshalBinary()
}

// BytesStruct echoes req as a *BytesResponse.
func BytesStruct(_ context.Context, req *JsonRequest) (*BytesResponse, error) {
	resp := newBytesResponse(req)
	return resp, nil
}

// BytesString returns req's BytesResponse encoding as a string.
func BytesString(_ context.Context, req *JsonRequest) (string, error) {
	data, err := newBytesResponse(req).MarshalBinary()
	return string(data), err
}
/*********************
Not Supported
*********************/
// MissingResponse is an unsupported endpoint signature: it has no response return value.
func MissingResponse(_ context.Context, _ *JsonRequest) error {
	return nil
}

// MissingError is an unsupported endpoint signature: it has no error return value.
func MissingError(_ context.Context, _ *JsonRequest) *JsonResponse {
	return nil
}

// MissingContext is an unsupported endpoint signature: it has no context.Context parameter.
func MissingContext(_ *JsonRequest) (*JsonResponse, error) {
	return nil, nil
}

// WrongErrorPosition is an unsupported endpoint signature: error is not the last return value.
// The unconventional ordering is intentional.
func WrongErrorPosition(_ context.Context, _ *JsonRequest) (error, *JsonResponse, int) {
	return nil, nil, 0
}

// WrongContextPosition is an unsupported endpoint signature: context.Context is not the first parameter.
func WrongContextPosition(_ *JsonRequest, _ context.Context) (*JsonResponse, int, error) {
	return nil, 0, nil
}

// ExtraInput is an unsupported endpoint signature: it takes an extra, unexpected parameter.
func ExtraInput(_ context.Context, _ *JsonRequest, _ string) (*JsonResponse, error) {
	return nil, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"dario.cat/mergo"
"fmt"
"github.com/cisco-open/go-lanai/pkg/web/template"
"net/http"
"reflect"
"strings"
)
// IndexPage renders "index.html.tmpl" with a model whose "Title" is "TemplateMVCTest".
func IndexPage(_ context.Context, _ *http.Request) (template.ModelView, error) {
	model := template.Model{"Title": "TemplateMVCTest"}
	mv := template.ModelView{View: "index.html.tmpl", Model: model}
	return mv, nil
}
// RedirectPage responds with a 302 Found redirect view targeting "/index".
func RedirectPage(_ context.Context, _ *http.Request) (*template.ModelView, error) {
	const target = "/index"
	mv := template.RedirectView(target, http.StatusFound, false)
	return mv, nil
}
// ModelPrintTmpl is the fmt format of one printed entry: "<dotted key>=<value>".
const ModelPrintTmpl = `%s=%v`

// PrintKV is a template function that prints model as newline-separated "key=value" lines.
// Nested maps, slices and structs are flattened into dotted keys, each starting with "."
// (e.g. ".Title=TemplateMVCTest", ".items.0=first").
// Lines follow Go map iteration order, so their order is not deterministic.
func PrintKV(model map[string]any) string {
	lines := flattenMap(model, "")
	return strings.Join(lines, "\n")
}

// flattenMap renders every entry of m as "key=value" lines, where key is prefix + "." + entry key:
//   - template.RequestContext, map[string]any and map[string]string values are flattened recursively.
//   - []any and []string elements are printed as "key.<index>=<element>".
//   - fmt.Stringer, fmt.GoStringer and error values are printed with %v.
//   - nil values are skipped.
//   - Structs (or pointers to structs) are converted with mergo.Map and flattened recursively;
//     anything else, including structs that convert to an empty map, is printed with %v.
func flattenMap[T any](m map[string]T, prefix string) []string {
	lines := make([]string, 0, len(m))
	for k, val := range m {
		key := prefix + "." + k
		var unknown interface{} = val
		switch v := unknown.(type) {
		case template.RequestContext:
			lines = append(lines, flattenMap(v, key)...)
		case map[string]any:
			lines = append(lines, flattenMap(v, key)...)
		case map[string]string:
			lines = append(lines, flattenMap(v, key)...)
		case []any:
			// the entry key is kept in the element key so elements of different slices don't collide
			for i, elem := range v {
				lines = append(lines, fmt.Sprintf(ModelPrintTmpl, fmt.Sprintf(`%s.%d`, key, i), elem))
			}
		case []string:
			for i, elem := range v {
				lines = append(lines, fmt.Sprintf(ModelPrintTmpl, fmt.Sprintf(`%s.%d`, key, i), elem))
			}
		case fmt.Stringer, fmt.GoStringer, error:
			lines = append(lines, fmt.Sprintf(ModelPrintTmpl, key, v))
		case nil:
			// nothing to print
		default:
			var converted map[string]interface{}
			if reflect.Indirect(reflect.ValueOf(v)).Kind() == reflect.Struct {
				converted = map[string]interface{}{}
				// best effort: conversion errors are ignored; if nothing was converted, v is printed as-is below
				_ = mergo.Map(&converted, v)
			}
			if len(converted) != 0 {
				lines = append(lines, flattenMap(converted, key)...)
			} else {
				lines = append(lines, fmt.Sprintf(ModelPrintTmpl, key, v))
			}
		}
	}
	return lines
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package weberror
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
)
// MappingBuilder builds a web.ErrorTranslateMapping through a fluent API.
type MappingBuilder struct {
	name          string
	matcher       web.RouteMatcher
	order         int
	condition     web.RequestMatcher
	translateFunc web.ErrorTranslateFunc
}

// New creates a MappingBuilder that applies to any route.
// The optional first argument is the mapping name; it defaults to "anonymous".
func New(name ...string) *MappingBuilder {
	b := &MappingBuilder{
		name:    "anonymous",
		matcher: matcher.AnyRoute(),
	}
	if len(name) > 0 {
		b.name = name[0]
	}
	return b
}

/*****************************
	Public
******************************/

// Name sets the mapping name.
func (b *MappingBuilder) Name(name string) *MappingBuilder {
	b.name = name
	return b
}

// Order sets the mapping order.
func (b *MappingBuilder) Order(order int) *MappingBuilder {
	b.order = order
	return b
}

// With uses translator.Translate as the error translate function.
func (b *MappingBuilder) With(translator web.ErrorTranslator) *MappingBuilder {
	return b.Use(translator.Translate)
}

// ApplyTo restricts the mapping to routes accepted by matcher.
func (b *MappingBuilder) ApplyTo(matcher web.RouteMatcher) *MappingBuilder {
	b.matcher = matcher
	return b
}

// Use sets the error translate function.
func (b *MappingBuilder) Use(translateFunc web.ErrorTranslateFunc) *MappingBuilder {
	b.translateFunc = translateFunc
	return b
}

// WithCondition sets an additional request condition for the mapping.
func (b *MappingBuilder) WithCondition(condition web.RequestMatcher) *MappingBuilder {
	b.condition = condition
	return b
}

// Build creates the mapping. A nil matcher falls back to matcher.AnyRoute(), and an empty
// name falls back to the matcher's "%v" representation. Build panics if no translate
// function was set via With or Use.
func (b *MappingBuilder) Build() web.ErrorTranslateMapping {
	if b.matcher == nil {
		b.matcher = matcher.AnyRoute()
	}
	if len(b.name) == 0 {
		b.name = fmt.Sprintf("%v", b.matcher)
	}
	if b.translateFunc == nil {
		panic(fmt.Errorf("unable to build '%s' error translation mapping: error translate function is required. please use With(...) or Use(...)", b.name))
	}
	return web.NewErrorTranslateMapping(b.name, b.order, b.matcher, b.condition, b.translateFunc)
}
/*****************************
Helpers
******************************/
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuatortest
import (
. "github.com/cisco-open/go-lanai/test/utils/gomega"
"github.com/onsi/gomega"
. "github.com/onsi/gomega"
"io"
"net/http"
"testing"
)
// AssertEnvResponse fails the test unless the "env" response lists "test" as its first
// active profile and contains at least one property source.
// Only V3 responses are supported.
func AssertEnvResponse(t *testing.T, resp *http.Response) {
	g := gomega.NewWithT(t)
	body, err := io.ReadAll(resp.Body)
	g.Expect(err).To(Succeed(), `env response body should be readable`)

	activeProfile := HaveJsonPathWithValue("$.activeProfiles[0]", "test")
	g.Expect(body).To(activeProfile, "env response should contains correct active profiles")
	g.Expect(body).To(HaveJsonPath("$.propertySources"), "env response should contains propertySources")
	g.Expect(body).To(HaveJsonPath("$.propertySources[0]"), "env response should contains non-empty propertySources")
}
// AssertAPIListResponse fails the test unless the "apilist" response contains at least one
// "endpoint" field at any depth.
// Only V3 responses are supported.
func AssertAPIListResponse(t *testing.T, resp *http.Response) {
	g := gomega.NewWithT(t)
	body, err := io.ReadAll(resp.Body)
	g.Expect(err).To(Succeed(), `apilist response body should be readable`)
	hasEndpoint := HaveJsonPath("$..endpoint")
	g.Expect(body).To(hasEndpoint, "apilist response should contain some endpoint field")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuatortest
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
. "github.com/cisco-open/go-lanai/test/utils/gomega"
"github.com/onsi/gomega"
. "github.com/onsi/gomega"
"io"
"net/http"
"testing"
)
// ExpectedHealthOptions mutates an ExpectedHealth; pass them to AssertHealthResponse.
type ExpectedHealthOptions func(h *ExpectedHealth)

// ExpectedHealth describes what AssertHealthResponse verifies.
type ExpectedHealth struct {
	Status             health.Status
	HasDetails         bool
	HasComponents      bool
	RequiredComponents []string
}

// ExpectHealth sets the expected overall status.
func ExpectHealth(status health.Status) ExpectedHealthOptions {
	return func(h *ExpectedHealth) { h.Status = status }
}

// ExpectHealthComponents expects components to be disclosed, including every named one in requiredComps.
func ExpectHealthComponents(requiredComps ...string) ExpectedHealthOptions {
	return func(h *ExpectedHealth) {
		h.RequiredComponents = requiredComps
		h.HasComponents = true
	}
}

// ExpectHealthDetails expects details to be disclosed.
func ExpectHealthDetails() ExpectedHealthOptions {
	return func(h *ExpectedHealth) { h.HasDetails = true }
}
// AssertHealthResponse fails the test if resp is not a correct "health" endpoint response.
// By default it expects status UP with neither details nor components disclosed.
// A response whose Content-Type is actuator.ContentTypeSpringBootV2 is checked as V2;
// any other response is checked as V3.
func AssertHealthResponse(t *testing.T, resp *http.Response, expectations ...ExpectedHealthOptions) {
	expected := ExpectedHealth{Status: health.StatusUp}
	for _, apply := range expectations {
		apply(&expected)
	}

	g := gomega.NewWithT(t)
	if resp.Header.Get("Content-Type") == actuator.ContentTypeSpringBootV2 {
		assertHealthResponseV2(t, g, resp, &expected)
		return
	}
	assertHealthResponseV3(t, g, resp, &expected)
}
// assertHealthResponseV3 checks a V3 health response body. The status is read from "$.status",
// components are looked up under "components" at any depth, and details under "details" at any depth.
func assertHealthResponseV3(_ *testing.T, g *gomega.WithT, resp *http.Response, exp *ExpectedHealth) {
	const (
		jsonPathComponents = "$..components"
		jsonPathDetails    = "$..details"
	)
	body, err := io.ReadAll(resp.Body)
	g.Expect(err).To(Succeed(), `health response body should be readable`)
	g.Expect(body).To(HaveJsonPathWithValue("$.status", exp.Status.String()), "health response should have status [%v]", exp.Status)

	switch {
	case exp.HasComponents:
		g.Expect(body).To(HaveJsonPath(jsonPathComponents), "v3 health response should have components")
		for _, name := range exp.RequiredComponents {
			jsonPath := fmt.Sprintf("$.components.%s", name)
			g.Expect(body).To(HaveJsonPath(jsonPath), "v3 health response should have '%s' status", name)
		}
	default:
		g.Expect(body).NotTo(HaveJsonPath(jsonPathComponents), "v3 health response should not have components")
	}

	if !exp.HasDetails {
		g.Expect(body).NotTo(HaveJsonPath(jsonPathDetails), "v3 health response should not have details")
		return
	}
	g.Expect(body).To(HaveJsonPath(jsonPathDetails), "v3 health response should have details")
}
// assertHealthResponseV2 checks a V2 health response body. The status value at "$.status" must
// contain the expected status, and components are looked up under "details".
// NOTE(review): the details check looks for a "detailed" key at any depth, which matches the
// "detailed" component name used by the mocked health indicator rather than a generic field; confirm.
func assertHealthResponseV2(_ *testing.T, g *gomega.WithT, resp *http.Response, exp *ExpectedHealth) {
	const (
		jsonPathComponents = "$..details"
		jsonPathDetails    = "$..detailed"
	)
	body, err := io.ReadAll(resp.Body)
	g.Expect(err).To(Succeed(), `health response body should be readable`)
	statusMatcher := ContainElement(exp.Status.String())
	g.Expect(body).To(HaveJsonPathWithValue("$.status", statusMatcher), "health response should have status [%v]", exp.Status)

	switch {
	case exp.HasComponents:
		g.Expect(body).To(HaveJsonPath(jsonPathComponents), "v2 health response should have components")
		for _, name := range exp.RequiredComponents {
			jsonPath := fmt.Sprintf("$.details.%s", name)
			g.Expect(body).To(HaveJsonPath(jsonPath), "v2 health response should have '%s' status", name)
		}
	default:
		g.Expect(body).NotTo(HaveJsonPath(jsonPathComponents), "v2 health response should not have components")
	}

	if !exp.HasDetails {
		g.Expect(body).NotTo(HaveJsonPath(jsonPathDetails), "v2 health response should not have details")
		return
	}
	g.Expect(body).To(HaveJsonPath(jsonPathDetails), "v2 health response should have details")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuatortest
import (
"fmt"
. "github.com/cisco-open/go-lanai/test/utils/gomega"
"github.com/onsi/gomega"
. "github.com/onsi/gomega"
"io"
"net/http"
"testing"
)
// ExpectedLoggersOptions mutates an ExpectedLoggers; pass them to AssertLoggersResponse.
type ExpectedLoggersOptions func(l *ExpectedLoggers)

// ExpectedLoggers describes what AssertLoggersResponse verifies.
// The level maps are keyed by logger name; a nil value only requires the level field to be present.
type ExpectedLoggers struct {
	Single           bool
	EffectiveLevels  map[string]interface{}
	ConfiguredLevels map[string]interface{}
}

// ExpectLoggersSingleEntry expects the loggers response to be a single entry.
// At most two expected levels are used:
//   - The 1st is the expected "effective level". If nil or missing, any value is accepted.
//   - The 2nd is the expected "configured level". If nil, any value is accepted;
//     if missing, it is not checked at all.
//
// Note: "effectiveLevel" is always expected to be present in any "loggers" response.
// This option replaces any previously configured level expectations.
func ExpectLoggersSingleEntry(expectedLevels ...interface{}) ExpectedLoggersOptions {
	return func(l *ExpectedLoggers) {
		// the logger name is irrelevant for a single-entry response, so a fixed placeholder key is used
		const placeholder = "logger"
		effective := make(map[string]interface{})
		configured := make(map[string]interface{})
		if n := len(expectedLevels); n > 0 {
			effective[placeholder] = expectedLevels[0]
			if n > 1 {
				configured[placeholder] = expectedLevels[1]
			}
		}
		l.Single = true
		l.EffectiveLevels = effective
		l.ConfiguredLevels = configured
	}
}

// ExpectLoggersEffectiveLevels adds expected effective levels. kvs is expected to be
// alternating logger name / level pairs, e.g. ("root", "INFO").
func ExpectLoggersEffectiveLevels(kvs ...string) ExpectedLoggersOptions {
	return func(l *ExpectedLoggers) { setKVs(l.EffectiveLevels, kvs) }
}

// ExpectLoggersConfiguredLevels adds expected configured levels. kvs is expected to be
// alternating logger name / level pairs, e.g. ("root", "INFO").
func ExpectLoggersConfiguredLevels(kvs ...string) ExpectedLoggersOptions {
	return func(l *ExpectedLoggers) { setKVs(l.ConfiguredLevels, kvs) }
}
// setKVs stores alternating key/value pairs from kvs into kvMap.
// For example, ["root", "INFO", "web", "DEBUG"] sets kvMap["root"]="INFO" and kvMap["web"]="DEBUG".
// A trailing key without a value is stored with an empty string value.
func setKVs(kvMap map[string]interface{}, kvs []string) {
	// step over pairs; visiting every index would also treat values as keys
	for i := 0; i < len(kvs); i += 2 {
		var v string
		if i+1 < len(kvs) {
			v = kvs[i+1]
		}
		kvMap[kvs[i]] = v
	}
}
// AssertLoggersResponse fails the test if resp is not a response of the "loggers" endpoint.
// By default it expects the response to include all loggers with their effective level,
// plus all supported levels. Use ExpectLoggersSingleEntry for single-logger responses.
// Only V3 responses are supported.
func AssertLoggersResponse(t *testing.T, resp *http.Response, expectations ...ExpectedLoggersOptions) {
	expected := ExpectedLoggers{
		EffectiveLevels:  make(map[string]interface{}),
		ConfiguredLevels: make(map[string]interface{}),
	}
	for _, apply := range expectations {
		apply(&expected)
	}

	g := gomega.NewWithT(t)
	if expected.Single {
		assertSingleLoggerResponse(t, g, resp, &expected)
		return
	}
	assertLoggersResponse(t, g, resp, &expected)
}
// assertLoggersResponse checks a full loggers response: "levels", "loggers" and each logger's
// "effectiveLevel" must be present, followed by the expected per-logger levels.
func assertLoggersResponse(t *testing.T, g *WithT, resp *http.Response, expected *ExpectedLoggers) {
	body, err := io.ReadAll(resp.Body)
	g.Expect(err).To(Succeed(), `loggers response body should be readable`)
	g.Expect(body).To(HaveJsonPath("$.levels"), "loggers response should contains 'levels'")
	g.Expect(body).To(HaveJsonPath("$.loggers"), "loggers response should contains 'loggers'")
	g.Expect(body).To(HaveJsonPath("$.loggers[*].effectiveLevel"), "loggers response should contains 'effectiveLevel'")

	loggerPath := func(name string) string {
		return fmt.Sprintf(`$.loggers["%s"]`, name)
	}
	assertLogLevels(t, g, body, expected, loggerPath)
}
// assertSingleLoggerResponse checks a single-logger response. The entry is the JSON root,
// so every logger name maps to "$".
func assertSingleLoggerResponse(t *testing.T, g *WithT, resp *http.Response, expected *ExpectedLoggers) {
	body, err := io.ReadAll(resp.Body)
	g.Expect(err).To(Succeed(), `loggers response body should be readable`)
	g.Expect(body).To(HaveJsonPath("$.effectiveLevel"), "loggers response should contains 'effectiveLevel'")

	rootPath := func(_ string) string { return "$" }
	assertLogLevels(t, g, body, expected, rootPath)
}
// assertLogLevels verifies every expected effective and configured level in body.
// loggerJsonPathFn maps a logger name to the JSON path of that logger's entry.
// A nil expected value only requires the level field to exist; any other value must match exactly.
func assertLogLevels(_ *testing.T, g *WithT, body []byte, expected *ExpectedLoggers, loggerJsonPathFn func(name string) string) {
	checks := []struct {
		field  string
		levels map[string]interface{}
	}{
		{field: "effectiveLevel", levels: expected.EffectiveLevels},
		{field: "configuredLevel", levels: expected.ConfiguredLevels},
	}
	for _, c := range checks {
		for name, level := range c.levels {
			jsonPath := loggerJsonPathFn(name) + "." + c.field
			if level == nil {
				g.Expect(body).To(HaveJsonPath(jsonPath), "loggers response should contains logger '%s' with %s", name, c.field)
				continue
			}
			// %v rather than %s: expected levels are interface{} and may not be strings
			g.Expect(body).To(HaveJsonPathWithValue(jsonPath, level), "loggers response should contains logger '%s' with %s=%v", name, c.field, level)
		}
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuatortest
// ActuatorOptions mutates an ActuatorOption; pass them to WithEndpoints.
type ActuatorOptions func(opt *ActuatorOption)

// ActuatorOption controls what WithEndpoints installs.
type ActuatorOption struct {
	// WithEndpoints defaults this to false. When set to true, the default health, info and env
	// endpoints are not initialized.
	DisableAllEndpoints bool
	// WithEndpoints defaults this to true. When set to false, the default authentication is installed.
	// Depending on the default authentication (currently tokenauth), more dependencies might be needed.
	DisableDefaultAuthentication bool
}

// DisableAllEndpoints is an ActuatorOptions that disables all default endpoints in tests.
// Any endpoint needed must be installed manually via apptest.WithModules(...).
func DisableAllEndpoints() ActuatorOptions {
	return func(opt *ActuatorOption) {
		opt.DisableAllEndpoints = true
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package actuatortest
import (
"github.com/cisco-open/go-lanai/pkg/actuator"
"github.com/cisco-open/go-lanai/pkg/actuator/env"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
healthep "github.com/cisco-open/go-lanai/pkg/actuator/health/endpoint"
"github.com/cisco-open/go-lanai/pkg/actuator/info"
actuatorinit "github.com/cisco-open/go-lanai/pkg/actuator/init"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"go.uber.org/fx"
)
// WithEndpoints is a convenient group of test options that enables actuator endpoints with the
// following configuration:
//   - "info", "health" and "env" are initialized.
//   - The default "tokenauth" authentication is disabled (the sectest package can be used to test security).
//   - The default properties and permission-based access control are used; custom access control can be registered.
//
// Note 1: The web testing environment is configured separately.
// See webtest.WithMockedServer() and webtest.WithRealServer().
//
// Note 2: Actuator endpoints usually require correct properties to be fully functional,
// so make sure the test has all "management" properties configured correctly.
//
// Note 3: Additional endpoints can be added by directly adding their Modules in the test.
//
// Example:
//
//	test.RunTest(context.Background(), t,
//		apptest.Bootstrap(),
//		webtest.WithMockedServer(),
//		sectest.WithMockedMiddleware(),
//		apptest.WithModules(
//			// additional endpoints
//			loggers.Module,
//		),
//		apptest.WithBootstrapConfigFS(testdata.MyTestBootstrapFS),
//		apptest.WithConfigFS(testdata.MyTestConfigFS),
//		apptest.WithProperties("more.properties: value"...),
//		test.GomegaSubTest(SubTestAdminEndpoints(), "MyTests"),
//	)
func WithEndpoints(opts ...ActuatorOptions) test.Options {
	opt := ActuatorOption{
		DisableAllEndpoints:          false,
		DisableDefaultAuthentication: true,
	}
	for _, apply := range opts {
		apply(&opt)
	}

	// core actuator infrastructure is always installed
	testOpts := []test.Options{
		apptest.WithModules(actuatorinit.Module, actuator.Module, errorhandling.Module, access.Module),
	}
	if !opt.DisableAllEndpoints {
		endpoints := apptest.WithModules(health.Module, healthep.Module, info.Module, env.Module)
		testOpts = append(testOpts, endpoints)
	}
	if opt.DisableDefaultAuthentication {
		noDefaultAuth := apptest.WithFxOptions(fx.Invoke(disableDefaultSecurity))
		testOpts = append(testOpts, noDefaultAuth)
	}
	return test.WithOptions(testOpts...)
}
// disableDefaultSecurity disables the auto-configured "tokenauth" authentication by registering
// a no-op actuator security customizer, which overrides the default one.
func disableDefaultSecurity(reg *actuator.Registrar) {
	noop := actuator.SecurityCustomizerFunc(func(_ security.WebSecurity) {
		// intentionally empty: registering any customizer overrides the default
	})
	reg.MustRegister(noop)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testdata
import (
"context"
"github.com/cisco-open/go-lanai/pkg/actuator/health"
)
// MockedHealthIndicator is a test health indicator with configurable status, description and details.
type MockedHealthIndicator struct {
	Status      health.Status
	Description string
	Details     map[string]interface{}
}

// NewMockedHealthIndicator returns an indicator reporting UP, described as "mocked",
// with details {"key": "value"}.
func NewMockedHealthIndicator() *MockedHealthIndicator {
	details := map[string]interface{}{"key": "value"}
	return &MockedHealthIndicator{
		Status:      health.StatusUp,
		Description: "mocked",
		Details:     details,
	}
}

// Name returns the indicator name "test".
func (i *MockedHealthIndicator) Name() string {
	const name = "test"
	return name
}

// Health returns a composite health with the configured status and description.
// When opts.ShowComponents is set, it adds a "simple" and a "detailed" component; the
// "detailed" component carries i.Details only when opts.ShowDetails is also set.
func (i *MockedHealthIndicator) Health(_ context.Context, opts health.Options) health.Health {
	ret := health.CompositeHealth{
		SimpleHealth: health.SimpleHealth{Stat: i.Status, Desc: i.Description},
	}
	if !opts.ShowComponents {
		return ret
	}

	simple := health.SimpleHealth{Stat: i.Status, Desc: "mocked simple"}
	detailed := health.DetailedHealth{
		SimpleHealth: health.SimpleHealth{Stat: i.Status, Desc: "mocked detailed"},
	}
	if opts.ShowDetails {
		detailed.Details = i.Details
	}
	ret.Components = map[string]health.Health{
		"simple":   simple,
		"detailed": detailed,
	}
	return ret
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package apptest
import (
"context"
"embed"
"github.com/cisco-open/go-lanai/pkg/appconfig"
appconfiginit "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/test"
"go.uber.org/fx"
"strings"
)
/*************************
Test Options
*************************/
// PropertyValuerFunc computes a property value from ctx. Used as values of WithDynamicProperties.
type PropertyValuerFunc func(ctx context.Context) interface{}
// WithConfigFS provides per-test config capability.
// It registers each given embed.FS as application config, which can override any defaults.
// Each embed.FS should contain at least one yml file.
// See appconfiginit.FxEmbeddedApplicationAdHoc.
func WithConfigFS(fs ...embed.FS) test.Options {
	opts := make([]fx.Option, 0, len(fs))
	for _, fsys := range fs {
		opts = append(opts, appconfiginit.FxEmbeddedApplicationAdHoc(fsys, ".", "testdata"))
	}
	return WithFxOptions(opts...)
}
// WithBootstrapConfigFS provides per-test config capability.
// It registers each given embed.FS as bootstrap config, in which properties like
// "config.file.search-path" can be overridden. Each embed.FS should contain at least one yml file.
// See appconfiginit.FxEmbeddedBootstrapAdHoc.
func WithBootstrapConfigFS(fs ...embed.FS) test.Options {
	opts := make([]fx.Option, 0, len(fs))
	for _, fsys := range fs {
		opts = append(opts, appconfiginit.FxEmbeddedBootstrapAdHoc(fsys, ".", "testdata"))
	}
	return WithFxPriorityOptions(opts...)
}
// WithProperties provides per-test config capability.
// It registers ad-hoc test application properties. Each entry may be:
//   - "dotted.properties=value"
//   - "dotted.properties: value"
//   - "dotted.properties.without.value", which implies the value is "true"
func WithProperties(kvs ...string) test.Options {
	provider := newTestConfigProviderWithKV(kvs)
	adHoc := appconfiginit.FxProvideApplicationAdHoc(provider)
	return WithFxOptions(adHoc)
}
// WithDynamicProperties provides per-test config capability.
// It registers ad-hoc test application properties whose values are produced by the given valuers.
func WithDynamicProperties(valuers map[string]PropertyValuerFunc) test.Options {
	kvMap := make(map[string]interface{}, len(valuers))
	for key, valuer := range valuers {
		kvMap[key] = valuer
	}
	provider := NewTestConfigProvider(kvMap)
	return WithFxOptions(appconfiginit.FxProvideApplicationAdHoc(provider))
}
// WithConfigFxProvider provides per-test config capability.
// The given values are forwarded as-is to appconfiginit.FxProvideApplicationAdHoc, which registers them as
// ad-hoc test application config providers (e.g. an appconfig.Provider, or presumably a constructor returning one).
// Note: Use it with caution. This is an advanced use case, typically used by other utility packages.
func WithConfigFxProvider(fxProvides ...interface{}) test.Options {
	adHoc := appconfiginit.FxProvideApplicationAdHoc(fxProvides...)
	return WithFxOptions(adHoc)
}
/*************************
appconfig.Provider
*************************/
// testConfigProvider implements appconfig.Provider and serves pre-defined key-value pairs as application properties.
// Keys are dotted property paths. On Load, string values are parsed with utils.ParseString, PropertyValuerFunc
// values are evaluated with the load context, and any other non-nil value is used as-is.
type testConfigProvider struct {
	appconfig.ProviderMeta
	kvs map[string]interface{}
}

// NewTestConfigProvider is for internal usage. Exported for cross-package reference.
// Use WithConfigFS, WithProperties or WithDynamicProperties instead.
func NewTestConfigProvider(kvs map[string]interface{}) *testConfigProvider {
	return &testConfigProvider{
		ProviderMeta: appconfig.ProviderMeta{Precedence: 0},
		kvs:          kvs,
	}
}

// newTestConfigProviderWithKV parses entries in the form of "a.b.c=v", "a.b.c: v" or "a.b.c" (implies "true").
// The key/value separator is whichever of '=' or ':' appears first, so values may contain the other character,
// e.g. "a.b.url: http://host?x=y" or "a.b.url=http://host". Keys and values are trimmed.
func newTestConfigProviderWithKV(kvs []string) *testConfigProvider {
	kvMap := make(map[string]interface{}, len(kvs))
	for _, e := range kvs {
		k, v := e, "true"
		if idx := strings.IndexAny(e, "=:"); idx >= 0 {
			k, v = e[:idx], e[idx+1:]
		}
		kvMap[strings.TrimSpace(k)] = strings.TrimSpace(v)
	}
	return NewTestConfigProvider(kvMap)
}

// Name returns the name of this provider.
func (p *testConfigProvider) Name() string {
	return "test-properties"
}

// Load resolves all key-value pairs and un-flattens them into p.Settings.
// p.Loaded reflects whether loading succeeded.
func (p *testConfigProvider) Load(ctx context.Context) (err error) {
	defer func() {
		p.Loaded = err == nil
	}()
	flatSettings := make(map[string]interface{}, len(p.kvs))
	for k, v := range p.kvs {
		switch val := v.(type) {
		case string:
			flatSettings[k] = utils.ParseString(val)
		case PropertyValuerFunc:
			flatSettings[k] = val(ctx)
		case nil:
			// nothing to set
		default:
			// values of other types (e.g. numbers or booleans given via NewTestConfigProvider) are used as-is
			// instead of being silently dropped
			flatSettings[k] = val
		}
	}
	p.Settings, err = appconfig.UnFlatten(flatSettings)
	return
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package apptest
import (
"context"
"embed"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/test"
"github.com/spf13/cobra"
"go.uber.org/fx"
"reflect"
"testing"
"time"
)
// TestDefaultConfigFS embeds "test-defaults.yml", registered as embedded defaults of the apptest application.
//go:embed test-defaults.yml
var TestDefaultConfigFS embed.FS
// TestBootstrapConfigFS embeds "test-bootstrap.yml", registered as ad-hoc bootstrap config of the apptest application.
//go:embed test-bootstrap.yml
var TestBootstrapConfigFS embed.FS
// TestApplicationConfigFS embeds "test-application.yml", registered as ad-hoc application config of the apptest application.
//go:embed test-application.yml
var TestApplicationConfigFS embed.FS
// testBootstrapper holds all configuration to bootstrap a fx-enabled test.
type testBootstrapper struct {
bootstrap.Bootstrapper
// AppPriorityOptions are fx options applied as priority steps of the test app (see WithFxPriorityOptions)
AppPriorityOptions []fx.Option
// AppOptions are fx options applied as regular steps of the test app (see WithFxOptions)
AppOptions []fx.Option
}
// Bootstrap is an entrypoint test.Options indicating that all sub tests should run within the scope of
// a slim version of bootstrap.App.
func Bootstrap() test.Options {
	runner := NewFxTestRunner()
	return test.WithInternalRunner(runner)
}
// NewFxTestRunner is internal use only, exported for cross-package reference.
// The returned test.InternalRunner:
//   - runs the test's setup hooks, and defers their teardown (in reverse order) until the runner returns;
//   - bootstraps a self-contained "testapp" application with appconfig.Module, the embedded test configs and all
//     fx options accumulated via WithFxPriorityOptions / WithFxOptions;
//   - runs all sub tests from a CLI runner within that application (see newTestCliRunner).
func NewFxTestRunner() test.InternalRunner {
return func(ctx context.Context, t *test.T) {
// run setup hooks.
// Note: arguments of the deferred teardown are evaluated immediately, so teardown hooks receive the context
// returned by setup hooks, not the context derived further below.
ctx = testSetup(ctx, t.T, t.TestHooks)
defer testTeardown(ctx, t.T, t.TestHooks)
// Use the test bootstrapper carried by ctx (put there by options such as WithModules/WithFxOptions), or create one.
// Test module options are registered to this bootstrapper rather than registering a module directly:
// the bootstrap package doesn't support module refresh (caused by singleton pattern), while
// bootstrap.ExecuteContainedApp() allows repeatedly bootstrapping a self-contained application.
tb, ok := ctx.Value(ctxKeyTestBootstrapper).(*testBootstrapper)
if !ok || tb == nil {
ctx, tb = withTestModule(ctx)
}
// default modules, and make values of the test context available from the app's initial context
tb.Register(appconfig.Module)
tb.AddInitialAppContextOptions(mergeInitContext(ctx))
// prepare bootstrap fx options: built-in test configs first, followed by per-test priority options
priority := append([]fx.Option{
fx.Supply(t),
appconfig.FxEmbeddedDefaults(TestDefaultConfigFS),
appconfig.FxEmbeddedBootstrapAdHoc(TestBootstrapConfigFS),
appconfig.FxEmbeddedApplicationAdHoc(TestApplicationConfigFS),
}, tb.AppPriorityOptions...)
// copied into a fresh slice so the bootstrapper's AppOptions slice is not shared
regular := append([]fx.Option{}, tb.AppOptions...)
// bootstrapping
//nolint:contextcheck // context is not passed on because the bootstrap process is not cancellable. This is a limitation
bootstrap.NewAppCmd("testapp", priority, regular,
func(cmd *cobra.Command) {
cmd.Use = "testapp"
cmd.Args = nil
},
)
// sub tests are executed by a CLI runner of the contained app
tb.EnableCliRunnerMode(newTestCliRunner)
bootstrap.ExecuteContainedApp(ctx, &tb.Bootstrapper)
}
}
// newTestCliRunner creates a bootstrap.OrderedCliRunner that runs all sub tests of t.
// It uses the lowest precedence so that the test runner always runs last.
func newTestCliRunner(t *test.T) bootstrap.OrderedCliRunner {
	runSubTests := func(ctx context.Context) error {
		test.InternalRunSubTests(ctx, t)
		// Failed tests are not reported as an error: Go's testing framework figures them out from t.Failed()
		return nil
	}
	return bootstrap.OrderedCliRunner{
		Precedence: order.Lowest,
		CliRunner:  runSubTests,
	}
}
// testSetup runs Setup of each hook in registration order, threading the returned context through.
// The test is aborted via t.Fatalf on the first hook error.
func testSetup(ctx context.Context, t *testing.T, hooks []test.Hook) context.Context {
	for i := range hooks {
		next, err := hooks[i].Setup(ctx, t)
		ctx = next
		if err != nil {
			t.Fatalf("error when setup test: %v", err)
		}
	}
	return ctx
}
// testTeardown runs Teardown of the given hooks in reverse order of registration.
// Every hook is torn down even if a previous one fails, so that resources are not left behind;
// each failure is reported via t.Errorf, which marks the test as failed.
func testTeardown(ctx context.Context, t *testing.T, hooks []test.Hook) {
	for i := len(hooks) - 1; i >= 0; i-- {
		if e := hooks[i].Teardown(ctx, t); e != nil {
			t.Errorf("error when teardown test: %v", e)
		}
	}
}
// mergeInitContext returns a bootstrap.ContextOption that merges the bootstrap-provided context with the given
// source contexts. The bootstrap-provided context comes first, so it takes priority for Value lookups.
func mergeInitContext(sources ...context.Context) bootstrap.ContextOption {
	return func(ctx context.Context) context.Context {
		all := make([]context.Context, 0, len(sources)+1)
		all = append(all, ctx)
		all = append(all, sources...)
		return newMergedContext(all...)
	}
}
/************************
Init Context
************************/
// newMergedContext creates a context.Context combining all given contexts, in order of priority:
//   - Done is closed as soon as any of the given contexts is done. When none of them can ever be cancelled
//     (all their Done channels are nil), Done returns nil, as context.Context permits for such contexts.
//   - Err returns the first non-nil error among the given contexts.
//   - Deadline returns the earliest deadline among the given contexts.
//   - Value returns the first non-nil value among the given contexts.
func newMergedContext(ctxList ...context.Context) context.Context {
	cases := make([]reflect.SelectCase, 0, len(ctxList))
	for _, ctx := range ctxList {
		// A nil Done channel never becomes ready. If only such channels were selected on, the watcher goroutine
		// below would block forever and leak, so they are excluded.
		if ch := ctx.Done(); ch != nil {
			cases = append(cases, reflect.SelectCase{
				Dir:  reflect.SelectRecv,
				Chan: reflect.ValueOf(ch),
			})
		}
	}
	mc := &mergedContext{sources: ctxList}
	if len(cases) == 0 {
		return mc
	}
	done := make(chan struct{})
	go func() {
		// blocks until any of the source contexts is done
		_, _, _ = reflect.Select(cases)
		close(done)
	}()
	mc.done = done
	return mc
}

// mergedContext is the context.Context created by newMergedContext.
type mergedContext struct {
	sources []context.Context
	done    <-chan struct{}
}

// Deadline returns the earliest deadline among all source contexts, if any of them has one.
func (mc mergedContext) Deadline() (earliest time.Time, ok bool) {
	for _, ctx := range mc.sources {
		if deadline, subOk := ctx.Deadline(); subOk && (earliest.IsZero() || deadline.Before(earliest)) {
			earliest = deadline
			ok = true
		}
	}
	return
}

// Done returns a channel closed once any source context is done, or nil if no source can be cancelled.
func (mc mergedContext) Done() <-chan struct{} {
	return mc.done
}

// Err returns the first non-nil error among source contexts, in order.
func (mc mergedContext) Err() error {
	for _, ctx := range mc.sources {
		if err := ctx.Err(); err != nil {
			return err
		}
	}
	return nil
}

// Value returns the first non-nil value associated with key among source contexts, in order.
func (mc mergedContext) Value(key any) any {
	for _, ctx := range mc.sources {
		if v := ctx.Value(key); v != nil {
			return v
		}
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package apptest
import (
"context"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/test"
"go.uber.org/fx"
"testing"
"time"
)
// WithDI populates the given DI targets using fx.Populate.
// All targets need to be pointers to structs, otherwise the test fails.
// See fx.Populate for more information.
func WithDI(diTargets ...interface{}) test.Options {
	populate := fx.Populate(diTargets...)
	return WithFxOptions(populate)
}
// WithModules registers the given modules to the test app.
func WithModules(modules ...*bootstrap.Module) test.Options {
	return test.Setup(func(ctx context.Context, _ *testing.T) (context.Context, error) {
		newCtx, tb := withTestModule(ctx)
		for i := range modules {
			tb.Register(modules[i])
		}
		return newCtx, nil
	})
}
// WithTimeout sets the fx start timeout of the test app (see fx.StartTimeout),
// preventing a blocked application startup from hanging the test process permanently.
func WithTimeout(timeout time.Duration) test.Options {
	startTimeout := fx.StartTimeout(timeout)
	return WithFxOptions(startTimeout)
}
// WithFxOptions registers the given fx.Option to the test app as regular steps.
// Options accumulate across multiple uses within the same test.
// See bootstrap.Module
func WithFxOptions(opts ...fx.Option) test.Options {
	setup := func(ctx context.Context, _ *testing.T) (context.Context, error) {
		newCtx, tb := withTestModule(ctx)
		tb.AppOptions = append(tb.AppOptions, opts...)
		return newCtx, nil
	}
	return test.Setup(setup)
}
// WithFxPriorityOptions registers the given fx.Option to the test app as priority steps, before any other modules.
// Options accumulate across multiple uses within the same test.
// See bootstrap.Module
func WithFxPriorityOptions(opts ...fx.Option) test.Options {
	setup := func(ctx context.Context, _ *testing.T) (context.Context, error) {
		newCtx, tb := withTestModule(ctx)
		tb.AppPriorityOptions = append(tb.AppPriorityOptions, opts...)
		return newCtx, nil
	}
	return test.Setup(setup)
}
// withTestModule returns the *testBootstrapper carried by ctx, together with ctx itself.
// If ctx carries none, a new bootstrapper is created and returned with a derived context carrying it.
func withTestModule(ctx context.Context) (context.Context, *testBootstrapper) {
	if existing, ok := ctx.Value(ctxKeyTestBootstrapper).(*testBootstrapper); ok && existing != nil {
		return ctx, existing
	}
	created := &testBootstrapper{
		Bootstrapper: *bootstrap.NewBootstrapper(),
	}
	return &testFxContext{Context: ctx, tb: created}, created
}
/*********************
Test FX Context
*********************/
// testBootstrapperCtxKey is the private context key type for the *testBootstrapper of the current test.
type testBootstrapperCtxKey struct{}

var ctxKeyTestBootstrapper = testBootstrapperCtxKey{}

// testFxContext is a context.Context carrying the *testBootstrapper of the current test.
type testFxContext struct {
	context.Context
	tb *testBootstrapper
}

// Value returns the carried *testBootstrapper for ctxKeyTestBootstrapper, and delegates any other key to
// the wrapped context.
func (c *testFxContext) Value(key interface{}) interface{} {
	if key != ctxKeyTestBootstrapper {
		return c.Context.Value(key)
	}
	return c.tb
}
// TestBootstrapper returns the *bootstrap.Bootstrapper of the current test context.
// It panics with a descriptive message if the context doesn't carry a test bootstrapper,
// e.g. when apptest.Bootstrap() is not used.
func TestBootstrapper(ctx context.Context) *bootstrap.Bootstrapper {
	tb, ok := ctx.Value(ctxKeyTestBootstrapper).(*testBootstrapper)
	// a typed nil would otherwise cause a nil-pointer dereference instead of the intended panic message
	if !ok || tb == nil {
		panic("TestBootstrapper is used without apptest.Bootstrap()")
	}
	return &tb.Bootstrapper
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
// Package consultest leverages the ittest package and HTTP VCR to record and replay consul operations.
package consultest
import (
"github.com/cisco-open/go-lanai/pkg/consul"
consulinit "github.com/cisco-open/go-lanai/pkg/consul/init"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"github.com/cisco-open/go-lanai/test/ittest"
"go.uber.org/fx"
"gopkg.in/dnaeon/go-vcr.v3/recorder"
"testing"
)
/*************************
Top-level APIs
*************************/
// ConsulRecorderOptions customizes a ConsulRecorderConfig.
type ConsulRecorderOptions func(cfg *ConsulRecorderConfig)
// ConsulRecorderConfig holds configuration of consul HTTP recording/playback.
type ConsulRecorderConfig struct {
// HTTPVCROptions are forwarded to ittest.WithHttpPlayback.
HTTPVCROptions []ittest.HTTPVCROptions
}
// WithHttpPlayback enables HTTP record/playback of consul operations via ittest.WithHttpPlayback,
// registers the consul module, and provides the recorder-backed consul option (see RecordedConsulProvider).
func WithHttpPlayback(t *testing.T, opts ...ConsulRecorderOptions) test.Options {
	var cfg ConsulRecorderConfig
	for _, apply := range opts {
		apply(&cfg)
	}
	return test.WithOptions(
		ittest.WithHttpPlayback(t, cfg.HTTPVCROptions...),
		apptest.WithModules(consulinit.Module),
		apptest.WithFxOptions(fx.Provide(RecordedConsulProvider())),
	)
}
/*************************
Top-level Options
*************************/
// HttpRecordingMode enables "recording" mode.
// IMPORTANT: When Record mode is enabled, all sub tests interact with the real Consul service.
// So use this mode on LOCAL DEV ONLY.
// See ittest.HttpRecordingMode()
func HttpRecordingMode() ConsulRecorderOptions {
	return func(cfg *ConsulRecorderConfig) {
		recording := ittest.HttpRecordingMode()
		cfg.HTTPVCROptions = append(cfg.HTTPVCROptions, recording)
	}
}
// MoreHTTPVCROptions appends additional ittest.HTTPVCROptions, which are forwarded to ittest.WithHttpPlayback.
func MoreHTTPVCROptions(opts ...ittest.HTTPVCROptions) ConsulRecorderOptions {
	return func(cfg *ConsulRecorderConfig) {
		for _, o := range opts {
			cfg.HTTPVCROptions = append(cfg.HTTPVCROptions, o)
		}
	}
}
/*************************
Tests Setup Helpers
*************************/
// RecordedConsulProvider returns an fx.Annotated that provides ConsulWithRecorder into the "consul" value group.
func RecordedConsulProvider() fx.Annotated {
	var provider fx.Annotated
	provider.Group = "consul"
	provider.Target = ConsulWithRecorder
	return provider
}
// ConsulWithRecorder returns a consul.Options that routes consul HTTP traffic through the given recorder:
//   - If cfg.Transport is set, it becomes the recorder's real transport, and cfg.HttpClient is replaced with
//     the recorder's default client.
//   - Otherwise, if cfg.HttpClient is set, its transport (when non-nil) becomes the recorder's real transport,
//     and the recorder is installed as the client's transport.
//   - Otherwise, cfg.HttpClient is set to the recorder's default client.
func ConsulWithRecorder(rec *recorder.Recorder) consul.Options {
	return func(cfg *consul.ClientConfig) error {
		switch {
		case cfg.Transport != nil:
			// keep the configured transport (e.g. pooling/TLS settings) for real requests in recording mode,
			// instead of silently dropping it
			rec.SetRealTransport(cfg.Transport)
			cfg.HttpClient = rec.GetDefaultClient()
		case cfg.HttpClient != nil:
			if cfg.HttpClient.Transport != nil {
				rec.SetRealTransport(cfg.HttpClient.Transport)
			}
			cfg.HttpClient.Transport = rec
		default:
			cfg.HttpClient = rec.GetDefaultClient()
		}
		return nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package dbtest
import (
"context"
"go.uber.org/fx"
"gorm.io/gorm"
"testing"
)
/*************************
Enums
*************************/
// mode is the copyist recording mode requested by dbtest options.
type mode int

const (
	// modeAuto leaves copyist's record flag untouched
	modeAuto mode = iota
	// modePlayback forces the record flag to "false"
	modePlayback
	// modeRecord forces the record flag to "true"
	modeRecord
)
/*************************
DBOptions
*************************/
// DBOptions customizes a DBOption.
type DBOptions func(opt *DBOption)

// DBOption holds the connection settings of the test database.
type DBOption struct {
	Host     string
	Port     int
	DBName   string
	Username string
	Password string
	SSL      bool
}

// DBName sets the database name.
func DBName(db string) DBOptions {
	return func(opt *DBOption) { opt.DBName = db }
}

// DBCredentials sets the username and password.
func DBCredentials(user, password string) DBOptions {
	return func(opt *DBOption) {
		opt.Username, opt.Password = user, password
	}
}

// DBPort sets the database port.
func DBPort(port int) DBOptions {
	return func(opt *DBOption) { opt.Port = port }
}

// DBHost sets the database host.
func DBHost(host string) DBOptions {
	return func(opt *DBOption) { opt.Host = host }
}
/*************************
TX context
*************************/
// mockedTxContext is a minimal transaction context whose Parent is the wrapped context.
type mockedTxContext struct {
	context.Context
}

// Parent returns the wrapped context.
func (m mockedTxContext) Parent() context.Context {
	return m.Context
}

// mockedGormContext is a mockedTxContext that additionally exposes a *gorm.DB.
type mockedGormContext struct {
	mockedTxContext
	db *gorm.DB
}

// DB returns the carried *gorm.DB.
func (m mockedGormContext) DB() *gorm.DB {
	return m.db
}
/*************************
Data Setup
*************************/
// DI holds the dependencies required by PrepareData and PrepareDataWithScope.
type DI struct {
fx.In
DB *gorm.DB
}
// DataSetupStep performs one data preparation step using db and returns the (possibly updated) context.
// Failures are expected to be reported via t: PrepareData and PrepareDataWithScope check t.Failed() after each step.
type DataSetupStep func(ctx context.Context, t *testing.T, db *gorm.DB) context.Context
// DataSetupScope prepares the context and *gorm.DB used by all DataSetupStep of PrepareDataWithScope.
type DataSetupScope func(ctx context.Context, t *testing.T, db *gorm.DB) (context.Context, *gorm.DB)
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package dbtest
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"regexp"
"strings"
"sync"
"testing"
"github.com/cisco-open/go-lanai/test"
"github.com/ghodss/yaml"
"github.com/onsi/gomega"
. "github.com/onsi/gomega"
"gorm.io/gorm"
)
// PrepareData is a convenient function that returns a test.SetupFunc executing the given DataSetupStep in order.
// Note: context changes applied by each step are accumulated and returned.
// Preparation stops with an error as soon as the test is marked as failed after a step.
func PrepareData(di *DI, steps ...DataSetupStep) test.SetupFunc {
	return func(ctx context.Context, t *testing.T) (context.Context, error) {
		for i := 0; i < len(steps); i++ {
			if ctx = steps[i](ctx, t, di.DB); t.Failed() {
				return ctx, errors.New("test failed during data preparation")
			}
		}
		return ctx, nil
	}
}
// PrepareDataWithScope is similar to PrepareData, but first applies the given DataSetupScope, which prepares the
// context and *gorm.DB used by all given DataSetupStep.
// Note: unlike PrepareData, context changes are not accumulated: the original context is always returned.
func PrepareDataWithScope(di *DI, scope DataSetupScope, steps ...DataSetupStep) test.SetupFunc {
	return func(ctx context.Context, t *testing.T) (context.Context, error) {
		stepCtx, scopedDB := scope(ctx, t, di.DB)
		for _, step := range steps {
			stepCtx = step(stepCtx, t, scopedDB)
			if !t.Failed() {
				continue
			}
			return ctx, errors.New("test failed during data preparation")
		}
		return ctx, nil
	}
}
// SetupUsingSQLFile returns a DataSetupStep that executes the given SQL files from fsys, in order.
// Each file may contain multiple statements, separated by a ";" at the end of a line.
func SetupUsingSQLFile(fsys fs.FS, filenames ...string) DataSetupStep {
	return func(ctx context.Context, t *testing.T, db *gorm.DB) context.Context {
		g := gomega.NewWithT(t)
		for i := range filenames {
			execSqlFile(ctx, fsys, db, g, filenames[i])
		}
		return ctx
	}
}
// SetupUsingSQLQueries returns a DataSetupStep that executes the given SQL queries in order,
// stopping at the first failure.
func SetupUsingSQLQueries(queries ...string) DataSetupStep {
	return func(ctx context.Context, t *testing.T, db *gorm.DB) context.Context {
		g := gomega.NewWithT(t)
		for _, query := range queries {
			result := db.WithContext(ctx).Exec(query)
			g.Expect(result.Error).To(Succeed(), "table preparation should be able to run SQL '%s'", query)
			if t.Failed() {
				break
			}
		}
		return ctx
	}
}
// SetupUsingModelSeedFile returns a DataSetupStep that loads the given yaml file from fsys, parses it directly
// into dest and creates the resulting models in batches of 100.
// When closures are given, they are invoked after seeding succeeds.
func SetupUsingModelSeedFile(fsys fs.FS, dest interface{}, filename string, closures ...func(ctx context.Context, db *gorm.DB)) DataSetupStep {
	return func(ctx context.Context, t *testing.T, db *gorm.DB) context.Context {
		g := gomega.NewWithT(t)
		loadErr := loadSeedData(fsys, dest, filename)
		g.Expect(loadErr).To(Succeed(), "data preparation should be able to parse model's seed file")
		if t.Failed() {
			return ctx
		}
		created := db.WithContext(ctx).CreateInBatches(dest, 100)
		g.Expect(created.Error).To(Succeed(), "data preparation should be able to create models seed file")
		if !t.Failed() {
			for _, closure := range closures {
				closure(ctx, db)
			}
		}
		return ctx
	}
}
// SetupTruncateTables returns a DataSetupStep that truncates the given tables in the provided order,
// using one "TRUNCATE TABLE ... CASCADE" statement per table.
func SetupTruncateTables(tables ...string) DataSetupStep {
	statements := make([]string, 0, len(tables))
	for _, table := range tables {
		statements = append(statements, truncateTableSql(table))
	}
	return SetupUsingSQLQueries(statements...)
}
// SetupDropTables returns a DataSetupStep that drops the given tables, if they exist, in a single
// "DROP TABLE IF EXISTS ... CASCADE" statement. When no table is given, the returned step does nothing.
func SetupDropTables(tables ...string) DataSetupStep {
	if len(tables) == 0 {
		// "DROP TABLE IF EXISTS  CASCADE;" would be invalid SQL, and there is nothing to drop anyway
		return func(ctx context.Context, _ *testing.T, _ *gorm.DB) context.Context {
			return ctx
		}
	}
	tableLiterals := make([]string, len(tables))
	for i, table := range tables {
		tableLiterals[i] = fmt.Sprintf(`"%s"`, table)
	}
	sql := fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE;`, strings.Join(tableLiterals, ", "))
	return SetupUsingSQLQueries(sql)
}
// SetupOnce returns a DataSetupStep that runs the given DataSetupStep within the given sync.Once.
// Only the invocation that actually runs the steps returns their accumulated context changes;
// other invocations return the given context unchanged.
// How sync.Once is scoped is up to the caller, e.g. once per test, once per package execution, etc.
func SetupOnce(once *sync.Once, steps ...DataSetupStep) DataSetupStep {
	return func(ctx context.Context, t *testing.T, db *gorm.DB) context.Context {
		runAll := func() {
			for i := range steps {
				ctx = steps[i](ctx, t, db)
			}
		}
		once.Do(runAll)
		return ctx
	}
}
// SetupWithGormScopes returns a DataSetupScope that applies the given gorm scopes to the *gorm.DB,
// leaving the context unchanged.
func SetupWithGormScopes(scopes ...func(*gorm.DB) *gorm.DB) DataSetupScope {
	return func(ctx context.Context, _ *testing.T, db *gorm.DB) (context.Context, *gorm.DB) {
		scoped := db.Scopes(scopes...)
		return ctx, scoped
	}
}
// sqlStatementSep matches a ";" ending a line (optionally followed by spaces); used to split SQL scripts into
// individual statements.
var sqlStatementSep = regexp.MustCompile(`(?m); *$`)

// execSqlFile executes every statement of the given SQL file from fsys, in order, asserting that each succeeds.
// Blank statements are skipped. The file is always closed after being read.
func execSqlFile(ctx context.Context, fsys fs.FS, db *gorm.DB, g *gomega.WithT, filename string) {
	file, e := fsys.Open(filename)
	g.Expect(e).To(Succeed(), "table preparation should be able to open SQL file '%s'", filename)
	if e != nil {
		// defensive: only reachable if the failed assertion above doesn't abort the test
		return
	}
	// the file is read-only, so a Close error carries no useful information here
	defer func() { _ = file.Close() }()
	content, e := io.ReadAll(file)
	g.Expect(e).To(Succeed(), "table preparation should be able to read SQL file '%s'", filename)
	if e != nil {
		return
	}
	for _, q := range sqlStatementSep.Split(string(content), -1) {
		if q = strings.TrimSpace(q); q == "" {
			continue
		}
		r := db.WithContext(ctx).Exec(q)
		g.Expect(r.Error).To(Succeed(), "table preparation should be able to run SQL file '%s'", filename)
	}
}
// loadSeedData reads filename from fsys and unmarshals its YAML content into dest.
func loadSeedData(fsys fs.FS, dest interface{}, filename string) error {
	data, err := fs.ReadFile(fsys, filename)
	if err != nil {
		return err
	}
	return yaml.Unmarshal(data, dest)
}
// truncateTableSql builds a "TRUNCATE TABLE ... CASCADE" statement for the given table name, wrapped in double quotes.
func truncateTableSql(table string) string {
	const stmt = `TRUNCATE TABLE "%s" CASCADE;`
	return fmt.Sprintf(stmt, table)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package dbtest
import (
"context"
"flag"
"fmt"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"github.com/cockroachdb/copyist"
"go.uber.org/fx"
"io"
"sync"
"testing"
)
const (
// flagCopyistRecordMode is the name of the command-line flag toggling copyist's recording mode
// (presumably registered by copyist itself; see `go test ... -record`).
flagCopyistRecordMode = "record"
)
// copyistCK is the private context key type for the copyist session closer.
type copyistCK struct{}
var (
// ctxKeyCopyistCloser stores the closer returned by copyist.Open (see openCopyistConn and closeCopyistConn).
ctxKeyCopyistCloser = copyistCK{}
// regOnce guards copyist.Register so that it's invoked at most once per process.
regOnce = sync.Once{}
)
// Keys of the key=value pairs used to build the postgres DSN (see toDSN).
const (
dsKeyHost = "host"
dsKeyPort = "port"
dsKeyDB = "dbname"
dsKeySslMode = "sslmode"
dsKeyUsername = "user"
dsKeyPassword = "password"
)
// withDB assembles test options for a copyist-backed postgres connection. It sets copyist's mode flag,
// registers the copyist postgres driver (once), opens/closes a copyist session around the test,
// provides the gorm.Dialector, and overrides the "data.db.*" properties.
// Defaults: host 127.0.0.1, port 26257, username "root", empty password.
func withDB(mode mode, dbName string, opts []DBOptions) []test.Options {
	setCopyistModeFlag(mode)

	opt := DBOption{
		Host:     "127.0.0.1",
		Port:     26257,
		DBName:   dbName,
		Username: "root",
	}
	for _, apply := range opts {
		apply(&opt)
	}

	props := []string{
		fmt.Sprintf("data.db.host: %s", opt.Host),
		fmt.Sprintf("data.db.port: %d", opt.Port),
		fmt.Sprintf("data.db.database: %s", opt.DBName),
		fmt.Sprintf("data.db.username: %s", opt.Username),
		fmt.Sprintf("data.db.password: %s", opt.Password),
	}
	return []test.Options{
		test.Setup(initializePostgresMock()),
		test.Setup(openCopyistConn(&opt)),
		test.Teardown(closeCopyistConn()),
		apptest.WithFxOptions(fx.Provide(testGormDialectorProvider(&opt))),
		apptest.WithProperties(props...),
	}
}
// initializePostgresMock returns a test.SetupFunc that registers copyist's wrapper of the "postgres" driver,
// at most once per process.
func initializePostgresMock() test.SetupFunc {
	register := func() { copyist.Register("postgres") }
	return func(ctx context.Context, _ *testing.T) (context.Context, error) {
		regOnce.Do(register)
		return ctx, nil
	}
}
// setCopyistModeFlag sets copyist's record flag according to the given mode.
// modePlayback sets it to "false", modeRecord to "true"; any other mode leaves the flag untouched.
func setCopyistModeFlag(mode mode) {
	var value string
	switch mode {
	case modePlayback:
		value = "false"
	case modeRecord:
		value = "true"
	default:
		return
	}
	mustSetFlag(flagCopyistRecordMode, value)
}
// openCopyistConn returns a test.SetupFunc that opens a copyist session for the current test.
// The session closer, if any, is stored in the returned context for closeCopyistConn.
// It fails when no database name is configured.
func openCopyistConn(opt *DBOption) test.SetupFunc {
	return func(ctx context.Context, t *testing.T) (context.Context, error) {
		if opt.DBName == "" {
			// return the given context instead of nil, so that callers never proceed with a nil context
			return ctx, fmt.Errorf("DBName is required for recording mode")
		}
		closer := copyist.Open(t)
		if closer == nil {
			return ctx, nil
		}
		return context.WithValue(ctx, ctxKeyCopyistCloser, closer), nil
	}
}
// closeCopyistConn returns a test.TeardownFunc that closes the copyist session stored by openCopyistConn, if any.
func closeCopyistConn() test.TeardownFunc {
	return func(ctx context.Context, _ *testing.T) error {
		closer, ok := ctx.Value(ctxKeyCopyistCloser).(io.Closer)
		if !ok {
			return nil
		}
		return closer.Close()
	}
}
// mustSetFlag sets the named command-line flag to value, panicking if the flag is undefined or the value is invalid.
func mustSetFlag(name, value string) {
	if err := flag.Set(name, value); err != nil {
		panic(err)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package dbtest
import (
	"fmt"
	"sort"
	"strings"

	"github.com/cisco-open/go-lanai/pkg/data"
	"github.com/cisco-open/go-lanai/pkg/data/postgresql"
	"go.uber.org/fx"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	gormtest "gorm.io/gorm/utils/tests"
)
/*****************************
gorm postgres Dialector
*****************************/
// dialectorDI is the (currently empty) fx parameter object of the constructor returned by testGormDialectorProvider.
type dialectorDI struct {
fx.In
}
// testGormDialectorProvider returns a fx constructor of gorm.Dialector that connects via the copyist-wrapped
// postgres driver ("copyist_postgres"), using the given options.
// The options are read when the constructor is invoked, not when this function is called.
func testGormDialectorProvider(opt *DBOption) func(di dialectorDI) gorm.Dialector {
	return func(di dialectorDI) gorm.Dialector {
		// "enable" is not a valid postgres sslmode (valid values are disable, allow, prefer, require,
		// verify-ca and verify-full), so SSL is requested with "require"
		ssl := "disable"
		if opt.SSL {
			ssl = "require"
		}
		options := map[string]interface{}{
			dsKeyHost:    opt.Host,
			dsKeyPort:    opt.Port,
			dsKeyDB:      opt.DBName,
			dsKeySslMode: ssl,
		}
		// credentials are only included when a username is configured
		if opt.Username != "" {
			options[dsKeyUsername] = opt.Username
			options[dsKeyPassword] = opt.Password
		}
		config := postgres.Config{
			DriverName: "copyist_postgres",
			DSN:        toDSN(options),
		}
		return postgresql.NewGormDialectorWithConfig(config)
	}
}
// toDSN builds a libpq-style keyword/value connection string ("k1=v1 k2=v2 ...") from the given options.
// Keys are sorted so that the result is deterministic. Values are formatted with %v, and quoted as required
// by the keyword/value format: empty values, or values containing whitespace, single quotes or backslashes,
// are wrapped in single quotes, with quotes and backslashes escaped by a backslash.
func toDSN(options map[string]interface{}) string {
	keys := make([]string, 0, len(options))
	for k := range options {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	pairs := make([]string, 0, len(keys))
	for _, k := range keys {
		pairs = append(pairs, k+"="+quoteDSNValue(fmt.Sprintf("%v", options[k])))
	}
	return strings.Join(pairs, " ")
}

// quoteDSNValue quotes v for a keyword/value connection string if needed.
// Without quoting, an empty value (e.g. an empty password) would swallow the following "key=value" pair.
func quoteDSNValue(v string) string {
	if v != "" && !strings.ContainsAny(v, " \t\n\r\v\f'\\") {
		return v
	}
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
	return "'" + escaper.Replace(v) + "'"
}
/****************************
gorm Noop Dialector
****************************/
// noopGormDialector is a gorm.Dialector based on gorm's DummyDialector, with SavePoint and RollbackTo doing nothing.
type noopGormDialector struct {
	gormtest.DummyDialector
}

// provideNoopGormDialector provides a zero-value noopGormDialector as gorm.Dialector.
func provideNoopGormDialector() gorm.Dialector {
	var dialector noopGormDialector
	return dialector
}

// SavePoint does nothing and always succeeds.
func (noopGormDialector) SavePoint(*gorm.DB, string) error {
	return nil
}

// RollbackTo does nothing and always succeeds.
func (noopGormDialector) RollbackTo(*gorm.DB, string) error {
	return nil
}
/*****************************
gorm cockroach error
*****************************/
// pqErrorTranslatorProvider provides postgresql.PostgresErrorTranslator as data.ErrorTranslator into the
// data.GormConfigurerGroup value group.
func pqErrorTranslatorProvider() fx.Annotated {
	newTranslator := func() data.ErrorTranslator {
		return postgresql.PostgresErrorTranslator{}
	}
	return fx.Annotated{
		Group:  data.GormConfigurerGroup,
		Target: newTranslator,
	}
}
/*****************************
gorm dry run
*****************************/
// enableGormDryRun enables DryRun (SQL is generated but not executed) and SkipDefaultTransaction on the given *gorm.DB.
func enableGormDryRun(db *gorm.DB) {
	db.SkipDefaultTransaction = true
	db.DryRun = true
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package dbtest
import (
"embed"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/data"
"github.com/cisco-open/go-lanai/pkg/data/tx"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"github.com/cisco-open/go-lanai/test/suitetest"
"github.com/cockroachdb/copyist"
"go.uber.org/fx"
)
//var logger = log.New("T.DB")
// defaultConfigFS embeds "defaults-dbtest.yml", registered by withData as embedded default properties.
//go:embed defaults-dbtest.yml
var defaultConfigFS embed.FS
// EnableDBRecordMode force-enables DB recording mode for the test package.
// Normally, recording mode should be enabled via `go test` argument `-record`.
// IMPORTANT: when Record mode is enabled, all tests execute SQL against the actual database.
// If Opensearch is being used, any queries to it will also be executed against the real Opensearch service.
// So use this mode on LOCAL DEV ONLY, and back up the DB before executing.
func EnableDBRecordMode() suitetest.PackageOptions {
	enable := func() error {
		setCopyistModeFlag(modeRecord)
		return nil
	}
	return suitetest.Setup(enable)
}
// WithDBPlayback enables DB SQL playback capabilities supported by `copyist`.
// This mode requires apptest.Bootstrap to work, and should not be used together with WithNoopMocks.
// Each top-level test should have corresponding recorded SQL responses in the `testdata` folder, or the test will fail.
// To enable record mode, use `go test ... -record` at CLI, or do it programmatically with EnableDBRecordMode.
// See https://github.com/cockroachdb/copyist for more details.
func WithDBPlayback(dbName string, opts ...DBOptions) test.Options {
	dbOpts := withDB(modeAuto, dbName, opts)
	all := append(dbOpts, withData()...)
	return test.WithOptions(all...)
}
// IsRecording reports whether copyist is in recording mode.
func IsRecording() bool {
	recording := copyist.IsRecording()
	return recording
}
// WithNoopMocks provides a noop tx.TxManager and a noop gorm.Dialector (based on gorm's DummyDialector).
// This mode requires apptest.Bootstrap to work, and should not be used together with WithDBPlayback.
// Note: in this mode, gorm.DB's DryRun and SkipDefaultTransaction are enabled.
func WithNoopMocks() test.Options {
	opts := withData()
	mocks := apptest.WithFxOptions(
		fx.Provide(provideNoopTxManager),
		fx.Provide(provideNoopGormDialector),
		fx.Invoke(enableGormDryRun),
	)
	return test.WithOptions(append(opts, mocks)...)
}
// withData returns options registering the data and tx modules, dbtest's embedded default properties,
// and the postgres error translator.
func withData() []test.Options {
	modules := apptest.WithModules(data.Module, tx.Module)
	extras := apptest.WithFxOptions(
		appconfig.FxEmbeddedDefaults(defaultConfigFS),
		fx.Provide(pqErrorTranslatorProvider()),
	)
	return []test.Options{modules, extras}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package dbtest
import (
"context"
"database/sql"
"github.com/cisco-open/go-lanai/pkg/data/tx"
"gorm.io/gorm"
)
type noopTxManager struct {}
func provideNoopTxManager() tx.TxManager {
return noopTxManager{}
}
func (m noopTxManager) Transaction(ctx context.Context, fn tx.TxFunc, _ ...*sql.TxOptions) error {
return fn(m.mockTxContext(ctx))
}
func (m noopTxManager) WithDB(_ *gorm.DB) tx.GormTxManager {
return m
}
func (m noopTxManager) Begin(ctx context.Context, _ ...*sql.TxOptions) (context.Context, error) {
return m.mockTxContext(ctx), nil
}
func (m noopTxManager) Rollback(ctx context.Context) (context.Context, error) {
if tc, ok := ctx.(tx.TxContext); ok {
return tc.Parent(), nil
}
return ctx, nil
}
func (m noopTxManager) Commit(ctx context.Context) (context.Context, error) {
if tc, ok := ctx.(tx.TxContext); ok {
return tc.Parent(), nil
}
return ctx, nil
}
// SavePoint is a no-op: the save point name is ignored and ctx is returned unchanged.
func (m noopTxManager) SavePoint(ctx context.Context, _ string) (context.Context, error) {
	unchanged := ctx
	return unchanged, nil
}

// RollbackTo is the no-op counterpart of SavePoint: ctx is returned unchanged.
func (m noopTxManager) RollbackTo(ctx context.Context, _ string) (context.Context, error) {
	unchanged := ctx
	return unchanged, nil
}
// mockTxContext wraps ctx in a mockedGormContext. The context carries a *gorm.DB that has only
// zero-value Config and Statement set, so it is non-nil but not connected to anything.
func (m noopTxManager) mockTxContext(ctx context.Context) context.Context {
	db := &gorm.DB{
		Config:    &gorm.Config{},
		Statement: &gorm.Statement{},
	}
	txCtx := mockedTxContext{Context: ctx}
	return &mockedGormContext{
		mockedTxContext: txCtx,
		db:              db,
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package embedded
import (
"context"
"crypto/tls"
"fmt"
"github.com/alicebob/miniredis/v2"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"github.com/cisco-open/go-lanai/test/suitetest"
"math/rand"
"testing"
"time"
)
// embeddedRedisCtxKey is the private context key type for the embedded redis server.
// A dedicated named type guarantees the key cannot collide with context keys defined elsewhere:
// an anonymous struct{}{} value would compare equal to every other struct{}{} key in the process.
type embeddedRedisCtxKey struct{}

// kCtxEmbeddedRedis is the context key under which the running *miniredis.Miniredis is stored.
var kCtxEmbeddedRedis = embeddedRedisCtxKey{}
/*******************
Public
*******************/
// RedisOptions customizes a RedisConfig before the embedded redis server is started.
type RedisOptions func(cfg *RedisConfig)

// RedisConfig configures the embedded redis server.
type RedisConfig struct {
	// Port must be between 32768 and 65535.
	// 0 is also accepted and results in listening on "127.0.0.1:0"
	// (NOTE(review): with a standard net listener this means an OS-assigned port — confirm).
	Port int
	// TLS, when set, runs the redis server in TLS mode.
	// Note: due to internal implementation, the Port is ignored when running in TLS mode.
	TLS *tls.Config
}
// Redis starts an embedded redis server at test package level, on a random port unless opts
// override it (see WithRedis). The actual port can be retrieved using CurrentRedisPort.
func Redis(opts ...RedisOptions) suitetest.PackageOptions {
	perTest := WithRedis(opts...)
	return suitetest.TestOptions(perTest)
}
// WithRedis starts an embedded redis server on a per-test basis.
// By default the server listens on a random port in the inclusive range 32768-65535;
// opts may override the port or enable TLS.
// The actual port can be retrieved using CurrentRedisPort.
func WithRedis(opts ...RedisOptions) test.Options {
	//nolint:gosec // Not security related
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	cfg := RedisConfig{
		// 0x8000 (32768) + [0, 0x7fff] covers 32768..65535 inclusive, matching the documented range.
		Port: 0x8000 + r.Intn(0x8000),
	}
	for _, fn := range opts {
		fn(&cfg)
	}
	return redisWithConfig(&cfg)
}
// EnableTLS returns a RedisOptions that runs the embedded redis server in TLS mode.
// Certificates come from ServerTLSWithCerts, and certs may override where they are loaded from.
// Loading happens eagerly, when EnableTLS is called.
// If loading fails, a warning is logged. The returned option then sets RedisConfig.TLS to the nil
// config returned on error, which overrides any TLS config set by earlier options and leaves the
// server in plain (non-TLS) mode.
func EnableTLS(certs ...func(src *TLSCerts)) RedisOptions {
	tlsCfg, e := ServerTLSWithCerts(certs...)
	apply := func(cfg *RedisConfig) {
		cfg.TLS = tlsCfg
	}
	if e != nil {
		logger.Warnf(`unable to enable TLS: %v`, e)
	}
	return apply
}
// RedisWithPort starts redis at the given port (must be between 32768 and 65535) at test package level.
//
// Deprecated: use Redis(...) with a RedisOptions that sets RedisConfig.Port instead.
func RedisWithPort(port int) suitetest.PackageOptions {
	return Redis(func(cfg *RedisConfig) {
		cfg.Port = port
	})
}
// CurrentRedisPort getter to return embedded redis port. returns -1 if it's not initialized or started
func CurrentRedisPort(ctx context.Context) (port int) {
port = -1
srv, ok := ctx.Value(kCtxEmbeddedRedis).(*miniredis.Miniredis)
if !ok {
return
}
ret := doWithEmbeddedRedis(srv, func(srv *miniredis.Miniredis) interface{} {
return srv.Server().Addr().Port
})
switch v := ret.(type) {
case int:
return v
}
return
}
// CurrentRedisServer returns the embedded redis server stored in ctx, or nil if ctx holds none.
func CurrentRedisServer(ctx context.Context) *miniredis.Miniredis {
	if srv, ok := ctx.Value(kCtxEmbeddedRedis).(*miniredis.Miniredis); ok {
		return srv
	}
	return nil
}
/*******************
Internals
*******************/
// redisWithConfig builds test options that:
//   - start an embedded redis server from cfg during test setup and store it in the test context,
//   - expose the "redis.addrs" dynamic property pointing at that server's port,
//   - stop the server during test teardown.
func redisWithConfig(cfg *RedisConfig) test.Options {
	setup := test.Setup(func(ctx context.Context, t *testing.T) (context.Context, error) {
		srv, e := startEmbeddedRedis(cfg)
		if e != nil {
			return ctx, e
		}
		return context.WithValue(ctx, kCtxEmbeddedRedis, srv), nil
	})
	props := apptest.WithDynamicProperties(map[string]apptest.PropertyValuerFunc{
		"redis.addrs": func(ctx context.Context) interface{} {
			return fmt.Sprintf("127.0.0.1:%d", CurrentRedisPort(ctx))
		},
	})
	teardown := test.Teardown(func(ctx context.Context, t *testing.T) error {
		srv, ok := ctx.Value(kCtxEmbeddedRedis).(*miniredis.Miniredis)
		if ok {
			stopEmbeddedRedis(srv)
		}
		return nil
	})
	return test.WithOptions(setup, props, teardown)
}
// startEmbeddedRedis creates and starts a miniredis server according to cfg:
//   - in TLS mode when cfg.TLS is set (cfg.Port is ignored),
//   - otherwise on 127.0.0.1:cfg.Port, where cfg.Port must be 0 or greater than 0x7fff.
//
// On success the server address is logged. On failure err is non-nil; server is nil for an invalid
// port, or the created-but-not-started instance if starting failed.
func startEmbeddedRedis(cfg *RedisConfig) (server *miniredis.Miniredis, err error) {
	if cfg.TLS == nil && cfg.Port != 0 && cfg.Port <= 0x7fff {
		return nil, fmt.Errorf("invalid embedded redis port [%d], should be > 0x7fff", cfg.Port)
	}
	server = miniredis.NewMiniRedis()
	if cfg.TLS != nil {
		// TLS mode
		err = server.StartTLS(cfg.TLS)
	} else {
		err = server.StartAddr(fmt.Sprintf("127.0.0.1:%d", cfg.Port))
	}
	if err != nil {
		return server, err
	}
	logger.Infof("Embedded Redis started at %s", server.Addr())
	return server, nil
}
// stopEmbeddedRedis closes server and logs the address it was listening on. A nil server is ignored.
func stopEmbeddedRedis(server *miniredis.Miniredis) {
	if server == nil {
		return
	}
	// read the address before Close so it can still be reported afterwards
	addr := server.Addr()
	server.Close()
	logger.Infof("Embedded Redis stopped at %s", addr)
}
// doWithEmbeddedRedis runs fn while holding the server's lock and returns fn's result.
// It returns nil without calling fn when server is nil or server.Server() is nil (server not started).
func doWithEmbeddedRedis(server *miniredis.Miniredis, fn func(srv *miniredis.Miniredis) interface{}) interface{} {
	if server == nil {
		return nil
	}
	server.Lock()
	defer server.Unlock()
	if server.Server() == nil {
		return nil
	}
	return fn(server)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package embedded
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/fs"
"os"
)
// TLSCerts describes where ServerTLSWithCerts loads certificates from.
type TLSCerts struct {
	// FS is the filesystem all paths are resolved against.
	FS fs.FS
	// Cert is the path of the certificate file in PEM format.
	Cert string
	// Key is the path of the certificate's private key file in unencrypted PEM format.
	Key string
	// CAs are the paths of CA certificate files in PEM format.
	CAs []string
}

// ServerTLSWithCerts constructs a server tls.Config from certificates stored in a filesystem.
// opts may override any TLSCerts field. The defaults are:
//   - FS: os.DirFS(".") (also used when an option sets FS to nil)
//   - Cert: "testdata/server.crt"
//   - Key: "testdata/server.key"
//   - CAs: ["testdata/ca.crt"]
//
// An error is returned in any of these cases:
//   - a file cannot be read;
//   - a CA file contains no PEM certificate;
//   - the certificate/key pair cannot be parsed.
//
// Read errors are wrapped, so errors.Is(err, fs.ErrNotExist) detects missing files.
// The returned config requires TLS 1.3, uses the CA pool as RootCAs and serves the loaded certificate.
func ServerTLSWithCerts(opts ...func(src *TLSCerts)) (*tls.Config, error) {
	src := TLSCerts{
		FS:   os.DirFS("."),
		Cert: "testdata/server.crt",
		Key:  "testdata/server.key",
		CAs:  []string{"testdata/ca.crt"},
	}
	for _, fn := range opts {
		fn(&src)
	}
	if src.FS == nil {
		// fs.ReadFile would panic on a nil filesystem
		src.FS = os.DirFS(".")
	}

	// start to load
	caPool := x509.NewCertPool()
	for _, path := range src.CAs {
		pemBytes, e := fs.ReadFile(src.FS, path)
		if e != nil {
			return nil, fmt.Errorf("unable to read CA file [%s]: %w", path, e)
		}
		// AppendCertsFromPEM silently skips unparsable data; a CA file without any certificate is a setup mistake
		if !caPool.AppendCertsFromPEM(pemBytes) {
			return nil, fmt.Errorf("no valid PEM certificate found in CA file [%s]", path)
		}
	}
	certBytes, e := fs.ReadFile(src.FS, src.Cert)
	if e != nil {
		return nil, fmt.Errorf("unable to read certificate file [%s]: %w", src.Cert, e)
	}
	keyBytes, e := fs.ReadFile(src.FS, src.Key)
	if e != nil {
		return nil, fmt.Errorf("unable to read private key file [%s]: %w", src.Key, e)
	}
	cert, e := tls.X509KeyPair(certBytes, keyBytes)
	if e != nil {
		return nil, fmt.Errorf("unable to parse certificate: %w", e)
	}
	return &tls.Config{
		MinVersion:   tls.VersionTLS13,
		RootCAs:      caPool,
		Certificates: []tls.Certificate{cert},
	}, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package examples
import (
"context"
"github.com/cisco-open/go-lanai/pkg/integrate/httpclient"
"github.com/cisco-open/go-lanai/pkg/integrate/security/scope"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/rest"
"net/http"
)
// ExampleRequest is the input of ExampleController.Remote, bound according to its json/form tags.
type ExampleRequest struct {
	// UseSystemAccount switches to Username via the system account; only consulted when Username is set.
	UseSystemAccount bool `json:"sysAcct" form:"sys_acct"`
	// Username is the user to switch to; when empty, the current security context is used as-is.
	Username string `json:"user" form:"user"`
}
// ExampleController exposes the "/remote" endpoints, delegating the work to ExampleService.
type ExampleController struct {
	Service *ExampleService
}

// NewExampleController creates an ExampleController backed by svc.
func NewExampleController(svc *ExampleService) web.Controller {
	ctrl := new(ExampleController)
	ctrl.Service = svc
	return ctrl
}
// Mappings registers GET and POST "/remote". Both require the "DUMMY_PERMISSION" permission
// and both are handled by Remote.
func (c *ExampleController) Mappings() []web.Mapping {
	getRemote := rest.Get("/remote").
		Condition(access.RequirePermissions("DUMMY_PERMISSION")).
		EndpointFunc(c.Remote).
		Build()
	postRemote := rest.Post("/remote").
		Condition(access.RequirePermissions("DUMMY_PERMISSION")).
		EndpointFunc(c.Remote).
		Build()
	return []web.Mapping{getRemote, postRemote}
}
// Remote makes a remote HTTP call under a security context chosen by req:
//   - no username: the current security context,
//   - username with UseSystemAccount: switch to that user via the system account,
//   - username otherwise: switch to that user directly.
func (c *ExampleController) Remote(ctx context.Context, req *ExampleRequest) (interface{}, error) {
	if len(req.Username) == 0 {
		return c.Service.CallRemoteWithCurrentContext(ctx)
	}
	if req.UseSystemAccount {
		return c.Service.CallRemoteWithSystemAccount(ctx, req.Username)
	}
	return c.Service.CallRemoteWithoutSystemAccount(ctx, req.Username)
}
// ExampleService makes remote HTTP calls through a client bound to the "usermanagementgoservice" service.
type ExampleService struct {
	HttpClient httpclient.Client
}

// NewExampleService derives a client for the "usermanagementgoservice" service from client.
// Any error from WithService is returned unchanged.
func NewExampleService(client httpclient.Client) (*ExampleService, error) {
	svcClient, e := client.WithService("usermanagementgoservice")
	if e != nil {
		return nil, e
	}
	return &ExampleService{HttpClient: svcClient}, nil
}
// CallRemoteWithSystemAccount switches to username using the system account, then makes the remote HTTP call.
// If scope.Do itself fails, its error is returned and any call result is discarded.
func (s *ExampleService) CallRemoteWithSystemAccount(ctx context.Context, username string) (ret interface{}, err error) {
	call := func(scoped context.Context) {
		ret, err = s.performRemoteHttpCall(scoped)
	}
	if e := scope.Do(ctx, call, scope.UseSystemAccount(), scope.WithUsername(username)); e != nil {
		return nil, e
	}
	return ret, err
}
// CallRemoteWithoutSystemAccount switches directly to username, then makes the remote HTTP call.
// If scope.Do itself fails, its error is returned and any call result is discarded.
func (s *ExampleService) CallRemoteWithoutSystemAccount(ctx context.Context, username string) (ret interface{}, err error) {
	call := func(scoped context.Context) {
		ret, err = s.performRemoteHttpCall(scoped)
	}
	if e := scope.Do(ctx, call, scope.WithUsername(username)); e != nil {
		return nil, e
	}
	return ret, err
}
// CallRemoteWithCurrentContext makes the remote HTTP call without switching security context.
func (s *ExampleService) CallRemoteWithCurrentContext(ctx context.Context) (ret interface{}, err error) {
	ret, err = s.performRemoteHttpCall(ctx)
	return ret, err
}
// performRemoteHttpCall sends GET "/api/v8/users/current" and returns the response body.
func (s *ExampleService) performRemoteHttpCall(ctx context.Context) (ret interface{}, err error) {
	req := httpclient.NewRequest("/api/v8/users/current", http.MethodGet)
	resp, e := s.HttpClient.Execute(ctx, req)
	if e != nil {
		return nil, e
	}
	return resp.Body, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ittest
import (
"context"
"errors"
"flag"
"fmt"
"github.com/cisco-open/go-lanai/pkg/integrate/httpclient"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"github.com/cisco-open/go-lanai/test/suitetest"
"go.uber.org/fx"
"gopkg.in/dnaeon/go-vcr.v3/cassette"
"gopkg.in/dnaeon/go-vcr.v3/recorder"
"net/http"
"strconv"
"testing"
"time"
)
// init registers the CLIRecordModeFlag bool flag (default true), unless a flag with that name
// was already registered elsewhere.
func init() {
	if existing := flag.Lookup(CLIRecordModeFlag); existing != nil {
		return
	}
	flag.Bool(CLIRecordModeFlag, true, "record external interaction")
}
// RecorderDI is an fx.In parameter struct that injects the HTTP VCR recorder and related components.
type RecorderDI struct {
	fx.In
	Recorder        *recorder.Recorder
	RecorderOption  *recorder.Options
	RecorderMatcher cassette.MatcherFunc
	HTTPVCROption   *HTTPVCROption
}

// recorderDI extends RecorderDI with the HttpRecorder wrapper; WithHttpPlayback populates it via apptest.WithDI.
type recorderDI struct {
	fx.In
	RecorderDI
	HTTPRecorder *HttpRecorder
}
// WithHttpPlayback enables remote HTTP server playback capabilities supported by `httpvcr` (gopkg.in/dnaeon/go-vcr.v3).
// This mode requires apptest.Bootstrap to work.
// Each top-level test should have corresponding recorded HTTP interactions in the `testdata` folder, or the test will fail.
// To enable record mode, use `go test ... --record-http` at CLI, or do it programmatically with HttpRecordingMode.
//
// Defaults, applied before opts (so opts can override them):
//   - the record is named after t.Name();
//   - records are sanitized (SanitizeHttpRecord);
//   - interaction durations are fixed to DefaultHTTPDuration (FixedHttpRecordDuration).
//
// The matcher is reset after each sub-test (see ResetRecorder). The recorder is stopped when the fx application stops.
// See https://github.com/dnaeon/go-vcr for more details.
func WithHttpPlayback(t *testing.T, opts ...HTTPVCROptions) test.Options {
	opts = append([]HTTPVCROptions{
		HttpRecordName(t.Name()),
		SanitizeHttpRecord(),
		FixedHttpRecordDuration(DefaultHTTPDuration),
	}, opts...)
	var di recorderDI
	testOpts := []test.Options{
		apptest.WithDI(&di),
		apptest.WithFxOptions(
			fx.Provide(
				httpRecorderProvider(opts),
			),
			fx.Invoke(httpRecorderCleanup),
		),
		test.SubTestSetup(recorderDISetup(&di)),
		test.SubTestTeardown(recorderReset()),
	}
	return test.WithOptions(testOpts...)
}
/****************************
Functions
****************************/
// Recorder extracts the HttpRecorder stored in ctx.
// It returns nil when ctx carries no HttpRecorder, or the stored one wraps no recorder.Recorder.
func Recorder(ctx context.Context) *HttpRecorder {
	rec, ok := ctx.Value(ckRecorder).(*HttpRecorder)
	if !ok || rec.Recorder == nil {
		return nil
	}
	return rec
}
// Client returns the recorder-backed http.Client of the HttpRecorder stored in ctx,
// or nil when no usable recorder is available.
func Client(ctx context.Context) *http.Client {
	rec, ok := ctx.Value(ckRecorder).(*HttpRecorder)
	if !ok || rec.Recorder == nil {
		return nil
	}
	return rec.GetDefaultClient()
}
// IsRecording reports whether the HttpRecorder stored in ctx is in recording mode.
// It returns false when no usable recorder is available.
func IsRecording(ctx context.Context) bool {
	rec, ok := ctx.Value(ckRecorder).(*HttpRecorder)
	if !ok || rec.Recorder == nil {
		return false
	}
	return rec.IsRecording()
}
// AdditionalMatcherOptions temporarily adds RecordMatcherOptions to the recorder in ctx. They are applied on top
// of the RecordMatching options the recorder was created with; the record-ordering setting is kept.
// The change lasts until ResetRecorder is called, which WithHttpPlayback does automatically after each sub-test.
// Note: The additional options are meant for the scope of a sub-test. For test-level options, use HttpRecordMatching.
// It is a no-op when ctx carries no recorder.
func AdditionalMatcherOptions(ctx context.Context, opts ...RecordMatcherOptions) {
	rec, ok := ctx.Value(ckRecorder).(*HttpRecorder)
	if !ok || rec.Recorder == nil {
		return
	}
	base := rec.Options.RecordMatching
	// build a fresh slice so the recorder's original options are never mutated
	merged := make([]RecordMatcherOptions, 0, len(base)+len(opts))
	merged = append(merged, base...)
	merged = append(merged, opts...)
	rec.SetMatcher(newCassetteMatcherFunc(merged, rec.Options.indexAwareWrapper))
}
// ResetRecorder restores the matcher the recorder in ctx was created with, reverting AdditionalMatcherOptions.
// It is a no-op when ctx carries no recorder.
// Note: tests configured via WithHttpPlayback call this automatically at sub-test teardown.
func ResetRecorder(ctx context.Context) {
	if rec, ok := ctx.Value(ckRecorder).(*HttpRecorder); ok && rec.Recorder != nil {
		rec.SetMatcher(rec.InitMatcher)
	}
}
// StopRecorder stops the recorder stored in ctx and returns an error if ctx carries none.
// Note: for tests configured via WithHttpPlayback, the recorder is already stopped when the fx application stops.
func StopRecorder(ctx context.Context) error {
	if rec, ok := ctx.Value(ckRecorder).(*HttpRecorder); ok && rec.Recorder != nil {
		return rec.Stop()
	}
	return fmt.Errorf("failed to stop recorder: no recorder found in context")
}
/*************************
Options
*************************/
// PackageHttpRecordingMode returns a suitetest.PackageOptions that enables HTTP recording mode for the entire package.
// It does so by setting the CLIRecordModeFlag flag to "true" during package setup.
// This is usually used in the TestMain function.
// Note: this option has no effect on tests using DisableHttpRecordingMode.
// e.g.
// <code>
//
//	func TestMain(m *testing.M) {
//		suitetest.RunTests(m,
//			PackageHttpRecordingMode(),
//		)
//	}
//
// </code>
func PackageHttpRecordingMode() suitetest.PackageOptions {
	enableRecording := func() error {
		return flag.Set(CLIRecordModeFlag, "true")
	}
	return suitetest.Setup(enableRecording)
}
// HttpRecordingMode returns a HTTPVCROptions that turns on recording mode.
// Normally recording mode should be enabled via the `go test` argument `-record-http`.
// Note: recording is still forced off if the flag is explicitly set with "-record-http=false".
// IMPORTANT: in recording mode, all sub-tests interact with the actual remote HTTP service,
// so use this mode on LOCAL DEV ONLY.
func HttpRecordingMode() HTTPVCROptions {
	enable := func(opt *HTTPVCROption) {
		opt.Mode = ModeRecording
	}
	return enable
}
// DisableHttpRecordingMode returns a HTTPVCROptions that forces replaying mode regardless of the command line flag.
func DisableHttpRecordingMode() HTTPVCROptions {
	replay := func(opt *HTTPVCROption) {
		opt.Mode = ModeReplaying
	}
	return replay
}
// HttpRecordName returns a HTTPVCROptions that sets the HTTP record's name (the cassette file's base name).
func HttpRecordName(name string) HTTPVCROptions {
	rename := func(opt *HTTPVCROption) {
		opt.Name = name
	}
	return rename
}
// HttpRecordMatching returns a HTTPVCROptions that appends custom matching options used to match
// outgoing requests against recorded ones.
func HttpRecordMatching(opts ...RecordMatcherOptions) HTTPVCROptions {
	add := func(opt *HTTPVCROption) {
		opt.RecordMatching = append(opt.RecordMatching, opts...)
	}
	return add
}
// HttpRecorderHooks returns a HTTPVCROptions that installs hooks.
// A hook whose Name() matches an already installed hook replaces it in place; any other hook is appended.
// If a hook also implements order.Ordered, its order is respected when the recorder is created.
func HttpRecorderHooks(hooks ...RecorderHook) HTTPVCROptions {
	return func(opt *HTTPVCROption) {
		for _, h := range hooks {
			replaced := false
			for j, installed := range opt.Hooks {
				if h.Name() == installed.Name() {
					opt.Hooks[j] = h
					replaced = true
					break
				}
			}
			if !replaced {
				opt.Hooks = append(opt.Hooks, h)
			}
		}
	}
}
// DisableHttpRecorderHooks returns a HTTPVCROptions that removes installed hooks by name.
// Names matching no installed hook are ignored.
// The relative order of the remaining hooks is preserved. NewHttpRecorder sorts hooks with a stable sort,
// so hooks without an explicit order are installed in the order they appear here.
func DisableHttpRecorderHooks(names ...string) HTTPVCROptions {
	return func(opt *HTTPVCROption) {
		for _, name := range names {
			for j := range opt.Hooks {
				if name == opt.Hooks[j].Name() {
					// shift instead of swap-with-last, so the remaining hooks keep their relative order
					opt.Hooks = append(opt.Hooks[:j], opt.Hooks[j+1:]...)
					break
				}
			}
		}
	}
}
// HttpRecordIgnoreHost is a convenient HTTPVCROptions that ignores the host when matching recorded requests.
// It is equivalent to HttpRecordMatching(IgnoreHost()).
func HttpRecordIgnoreHost() HTTPVCROptions {
	ignoreHost := IgnoreHost()
	return HttpRecordMatching(ignoreHost)
}
// HttpRecordOrdering toggles HTTP interaction order matching.
// When enforced, HTTP interactions have to happen in the recorded order.
// Otherwise, they may happen in any order, but each matched record can only be replayed once.
// Record ordering is enabled by default; enforcing it again keeps the existing wrapper.
func HttpRecordOrdering(enforced bool) HTTPVCROptions {
	return func(opt *HTTPVCROption) {
		switch {
		case !enforced:
			opt.indexAwareWrapper = nil
		case opt.indexAwareWrapper == nil:
			opt.indexAwareWrapper = newIndexAwareMatcherWrapper()
		}
	}
}
// DisableHttpRecordOrdering disables HTTP interaction order matching, which is enforced by default.
// With this option, HTTP interactions can happen in any order, but each matched record can only be replayed once.
func DisableHttpRecordOrdering() HTTPVCROptions {
	const enforced = false
	return HttpRecordOrdering(enforced)
}
// HttpTransport overrides the real transport used during recording mode. It has no effect in playback mode.
func HttpTransport(transport http.RoundTripper) HTTPVCROptions {
	useTransport := func(opt *HTTPVCROption) {
		opt.RealTransport = transport
	}
	return useTransport
}
// ApplyHttpLatency replays recorded HTTP latency. By default latency is skipped so tests run faster.
// It has no effect in recording mode.
func ApplyHttpLatency() HTTPVCROptions {
	applyLatency := func(opt *HTTPVCROption) {
		opt.SkipRequestLatency = false
	}
	return applyLatency
}
// SanitizeHttpRecord installs a hook that sanitizes requests and responses before they are saved to file.
// See SanitizingHook for details.
func SanitizeHttpRecord() HTTPVCROptions {
	sanitizer := SanitizingHook()
	return HttpRecorderHooks(sanitizer)
}
// FixedHttpRecordDuration installs a hook that sets a fixed duration on interactions before they are saved.
// A duration less than or equal to 0 removes that hook instead, so the actual latency is recorded.
// When HTTPVCROption.SkipRequestLatency is false (via ApplyHttpLatency), the recorded duration is applied during playback.
// See FixedDurationHook for details. This option has no effect in playback mode.
func FixedHttpRecordDuration(duration time.Duration) HTTPVCROptions {
	if duration > 0 {
		return HttpRecorderHooks(FixedDurationHook(duration))
	}
	return DisableHttpRecorderHooks(HookNameFixedDuration)
}
/****************************
RawRecorder Aware Context
****************************/
// recorderCtxKey is the private context key type for the *HttpRecorder stored by contextWithRecorder.
type recorderCtxKey struct{}

// ckRecorder is the context key under which the *HttpRecorder is stored.
var ckRecorder = recorderCtxKey{}
// recorderAwareContext is a context.Context that additionally resolves ckRecorder to an *HttpRecorder.
type recorderAwareContext struct {
	context.Context
	recorder *HttpRecorder
}

// contextWithRecorder returns parent wrapped so that ckRecorder resolves to rec.
// A nil rec leaves parent untouched.
func contextWithRecorder(parent context.Context, rec *HttpRecorder) context.Context {
	if rec != nil {
		return &recorderAwareContext{Context: parent, recorder: rec}
	}
	return parent
}

// Value returns the wrapped recorder for ckRecorder and delegates every other key to the parent context.
func (c *recorderAwareContext) Value(k interface{}) interface{} {
	if k == ckRecorder {
		return c.recorder
	}
	return c.Context.Value(k)
}
// recorderDISetup returns a test.SetupFunc that stores the injected HttpRecorder (if non-nil) in the test context.
func recorderDISetup(di *recorderDI) test.SetupFunc {
	setup := func(ctx context.Context, t *testing.T) (context.Context, error) {
		return contextWithRecorder(ctx, di.HTTPRecorder), nil
	}
	return setup
}
// recorderReset returns a test.TeardownFunc that restores the recorder's original matcher,
// undoing any AdditionalMatcherOptions made during the test. It never fails.
func recorderReset() test.TeardownFunc {
	teardown := func(ctx context.Context, t *testing.T) error {
		ResetRecorder(ctx)
		return nil
	}
	return teardown
}
/*************************
HttpRecorder
*************************/
// HttpRecorder wraps recorder.Recorder and keeps values that are normally inaccessible through the wrapped
// recorder.Recorder, such as the options and the matcher it was created with.
// Note: This type is for other test utilities to re-configure the wrapped recorder.Recorder.
type HttpRecorder struct {
	*recorder.Recorder
	// RawOptions are the go-vcr options the recorder was created with.
	RawOptions *recorder.Options
	// InitMatcher is the matcher installed at creation; ResetRecorder restores it.
	InitMatcher cassette.MatcherFunc
	// Options are the HTTPVCROption values the recorder was created from.
	Options *HTTPVCROption
}
// ContextWithNewHttpRecorder is a convenient function that creates a new HTTP recorder and stores it in a
// context derived from ctx. See NewHttpRecorder.
// The returned context can be used with context value accessors such as Client(ctx), IsRecording(ctx),
// AdditionalMatcherOptions(ctx), etc.
// On error, ctx is returned unchanged together with the error, so callers (e.g. test setup functions returning
// `ctx, err`) never propagate a nil context.
func ContextWithNewHttpRecorder(ctx context.Context, opts ...HTTPVCROptions) (context.Context, error) {
	rec, e := NewHttpRecorder(opts...)
	if e != nil {
		return ctx, e
	}
	return contextWithRecorder(ctx, rec), nil
}
// NewHttpRecorder creates a new HttpRecorder. Commonly used by:
//   - other test utilities that rely on HTTP recording (e.g. opensearchtest, consultest, etc.)
//   - unit tests that don't bootstrap dependency injection
//
// Defaults, applied before opts:
//   - records are saved under "testdata";
//   - InteractionIndexAwareHook is installed;
//   - recorded latency is skipped;
//   - interactions must be replayed in recorded order.
//
// Hooks are sorted with order.SortStable(order.OrderedFirstCompare) before being added to the recorder.
func NewHttpRecorder(opts ...HTTPVCROptions) (*HttpRecorder, error) {
	opt := HTTPVCROption{
		SavePath:           "testdata",
		Hooks:              []RecorderHook{InteractionIndexAwareHook()},
		SkipRequestLatency: true,
		indexAwareWrapper:  newIndexAwareMatcherWrapper(), // enforce order
	}
	for _, apply := range opts {
		apply(&opt)
	}

	rawOpts := toRecorderOptions(opt)
	raw, e := recorder.NewWithOptions(rawOpts)
	if e != nil {
		return nil, e
	}

	initMatcher := newCassetteMatcherFunc(opt.RecordMatching, opt.indexAwareWrapper)
	raw.SetMatcher(initMatcher)

	order.SortStable(opt.Hooks, order.OrderedFirstCompare)
	for _, hook := range opt.Hooks {
		raw.AddHook(hook.Handler(), hook.Kind())
	}

	return &HttpRecorder{
		Recorder:    raw,
		RawOptions:  rawOpts,
		InitMatcher: initMatcher,
		Options:     &opt,
	}, nil
}
/*************************
Internals
*************************/
// vcrDI collects additional HTTPVCROptions contributed to the fx value group "http-vcr".
type vcrDI struct {
	fx.In
	VCROptions []HTTPVCROptions `group:"http-vcr"`
}

// vcrOut lists the components httpRecorderProvider exposes to fx. This includes an httpclient.ClientCustomizer
// contributed to the "http-client" value group.
type vcrOut struct {
	fx.Out
	HTTPRecorder         *HttpRecorder
	RawRecorder          *recorder.Recorder
	CassetteMatcher      cassette.MatcherFunc
	HttpVCROption        *HTTPVCROption
	RawRecorderOption    *recorder.Options
	HttpClientCustomizer httpclient.ClientCustomizer `group:"http-client"`
}
// httpRecorderProvider returns an fx constructor that creates the HttpRecorder from opts, followed by any
// HTTPVCROptions contributed to the "http-vcr" group. It exposes the recorder, its raw components, and an
// httpclient.ClientCustomizer that sets the httpclient's HTTPClient to the recorder's default client.
func httpRecorderProvider(opts []HTTPVCROptions) func(di vcrDI) (vcrOut, error) {
	return func(di vcrDI) (vcrOut, error) {
		// Build a fresh slice: append(opts, ...) could write into the spare capacity of the captured opts,
		// which is shared by every invocation of this provider.
		finalOpts := make([]HTTPVCROptions, 0, len(opts)+len(di.VCROptions))
		finalOpts = append(finalOpts, opts...)
		finalOpts = append(finalOpts, di.VCROptions...)
		rec, e := NewHttpRecorder(finalOpts...)
		if e != nil {
			return vcrOut{}, e
		}
		return vcrOut{
			HTTPRecorder:      rec,
			RawRecorder:       rec.Recorder,
			CassetteMatcher:   rec.InitMatcher,
			HttpVCROption:     rec.Options,
			RawRecorderOption: rec.RawOptions,
			HttpClientCustomizer: httpclient.ClientCustomizerFunc(func(opt *httpclient.ClientOption) {
				opt.HTTPClient = rec.GetDefaultClient()
			}),
		}, nil
	}
}
func findBoolFlag(name string) (ret *bool) {
flag.Visit(func(f *flag.Flag) {
if f.Name != name {
return
}
var b bool
b, e := strconv.ParseBool(f.Value.String())
if e != nil {
b = true // default to true
}
ret = &b
})
return
}
// toRecorderOptions converts HTTPVCROption into go-vcr recorder.Options.
// The recorder records (recorder.ModeRecordOnly) only when either:
//   - opt.Mode is ModeRecording and the CLIRecordModeFlag flag is not explicitly set to false, or
//   - opt.Mode is ModeCommandline and the flag is explicitly set to true.
//
// In every other case it only replays. The cassette is named "<SavePath>/<Name>.httpvcr",
// or "<Name>.httpvcr" when SavePath is empty.
func toRecorderOptions(opt HTTPVCROption) *recorder.Options {
	cliFlag := findBoolFlag(CLIRecordModeFlag)
	flagSetTrue := cliFlag != nil && *cliFlag
	flagNotFalse := cliFlag == nil || *cliFlag

	mode := recorder.ModeReplayOnly
	if (opt.Mode == ModeRecording && flagNotFalse) || (opt.Mode == ModeCommandline && flagSetTrue) {
		mode = recorder.ModeRecordOnly
	}

	cassetteName := opt.Name + ".httpvcr"
	if opt.SavePath != "" {
		cassetteName = opt.SavePath + "/" + cassetteName
	}
	return &recorder.Options{
		CassetteName:       cassetteName,
		Mode:               mode,
		RealTransport:      opt.RealTransport,
		SkipRequestLatency: opt.SkipRequestLatency,
	}
}
// newCassetteMatcherFunc builds a cassette.MatcherFunc from opts.
// When indexAwareMatcher is non-nil, matching is additionally routed through it (record ordering).
func newCassetteMatcherFunc(opts []RecordMatcherOptions, indexAwareMatcher *indexAwareMatcherWrapper) cassette.MatcherFunc {
	matcherFn := NewRecordMatcher(opts...)
	if indexAwareMatcher != nil {
		return wrapRecordRequestMatcher(indexAwareMatcher.MatcherFunc(RecordMatcherFunc(matcherFn)))
	}
	return wrapRecordRequestMatcher(matcherFn)
}
// httpRecorderCleanup registers an fx OnStop hook that stops the recorder when the application stops.
func httpRecorderCleanup(lc fx.Lifecycle, rec *recorder.Recorder) {
	stop := func(_ context.Context) error {
		return rec.Stop()
	}
	lc.Append(fx.Hook{OnStop: stop})
}
// wrapRecordRequestMatcher adapts an error-returning matcher into a cassette.MatcherFunc.
// Any error means "no match". Mismatches other than errInteractionIDMismatch are logged at debug level
// to help diagnose missing interactions.
func wrapRecordRequestMatcher(fn GenericMatcherFunc[*http.Request, cassette.Request]) cassette.MatcherFunc {
	return func(out *http.Request, record cassette.Request) bool {
		e := fn(out, record)
		if e == nil {
			return true
		}
		if errors.Is(e, errInteractionIDMismatch) {
			return false
		}
		expected := fmt.Sprintf(`%s "%s"`, record.Method, record.URL)
		actual := fmt.Sprintf(`%s "%s"`, out.Method, out.URL.String())
		logger.Debugf("HTTP interaction missing: %s - %v: expect %s, but got %s",
			record.Headers.Get(xInteractionIndexHeader), e, expected, actual)
		return false
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ittest
import (
"fmt"
"github.com/ghodss/yaml"
"net/http"
"net/url"
"os"
"strconv"
)
// V1Cassette is a loosely-typed view of a version 1 cassette file, used here for format conversion.
type V1Cassette struct {
	Version      int             `json:"version"`
	Interactions []V1Interaction `json:"interactions"`
}

// V1Interaction holds the raw (untyped) request and response of a version 1 interaction.
type V1Interaction struct {
	Request  map[string]interface{} `json:"request"`
	Response map[string]interface{} `json:"response"`
}

// V2Cassette is a loosely-typed view of a version 2 cassette file, as produced by ConvertCassetteFileV1toV2.
type V2Cassette struct {
	Version      int             `json:"version"`
	Interactions []V2Interaction `json:"interactions"`
}

// V2Interaction is a version 2 interaction; unlike version 1 it carries an explicit ID.
type V2Interaction struct {
	ID       int                    `json:"id"`
	Request  map[string]interface{} `json:"request"`
	Response map[string]interface{} `json:"response"`
}
// ConvertCassetteFileV1toV2 is a utility function that helps migrate record files from cassette format version 1
// to version 2, the format used by httpvcr/v3. It reads src and writes the converted cassette to dest (mode 0644).
// Errors from reading, parsing, marshalling or writing are wrapped, so callers can inspect them with errors.Is/As.
// Note: Usually test authors should re-record interactions instead of using this utility. However, re-recording
// may not be possible, e.g. when the remote server setup is no longer available.
func ConvertCassetteFileV1toV2(src, dest string) error {
	srcBytes, e := os.ReadFile(src)
	if e != nil {
		return fmt.Errorf("unable to convert record file: %w", e)
	}
	var v1 V1Cassette
	if e := yaml.Unmarshal(srcBytes, &v1); e != nil {
		return fmt.Errorf("unable to convert record file, invalid source file: %w", e)
	}
	v2 := convertCassetteV1ToV2(v1)
	v2bytes, e := yaml.Marshal(v2)
	if e != nil {
		return fmt.Errorf("unable to convert record file: %w", e)
	}
	//nolint:gosec // G306: Expect WriteFile permissions to be 0600 or less - 0600 is too low
	if e := os.WriteFile(dest, v2bytes, 0644); e != nil {
		return fmt.Errorf("unable to convert record file, failed to write to destination: %w", e)
	}
	return nil
}
// convertCassetteV1ToV2 converts every interaction of v1, using its position as the interaction ID,
// and marks the result as version 2.
func convertCassetteV1ToV2(v1 V1Cassette) V2Cassette {
	interactions := make([]V2Interaction, 0, len(v1.Interactions))
	for idx := range v1.Interactions {
		interactions = append(interactions, convertInteractionV1ToV2(idx, v1.Interactions[idx]))
	}
	return V2Cassette{
		Version:      2,
		Interactions: interactions,
	}
}
// convertInteractionV1ToV2 converts a single version 1 interaction into a version 2 interaction with the given ID:
//   - the request gets a "host" field parsed from its "url" (when parsable), which is required for matching,
//   - the recorded request headers are kept, and xInteractionIndexHeader is set to id,
//   - the request gets an "order" field set to id,
//   - an empty response "duration" is removed.
//
// Note: the returned interaction shares its Request/Response maps with v1 (a nil Request is replaced by a new map).
func convertInteractionV1ToV2(id int, v1 V1Interaction) V2Interaction {
	v2 := V2Interaction{
		ID:       id,
		Request:  v1.Request,
		Response: v1.Response,
	}
	if v2.Request == nil {
		// writing fields below would panic on a nil map
		v2.Request = map[string]interface{}{}
	}
	// Add host field to each request (required for matching) if possible
	if rawUrl, ok := v2.Request["url"].(string); ok {
		if parsed, e := url.Parse(rawUrl); e == nil {
			v2.Request["host"] = parsed.Host
		}
	}
	// Add Interaction Index, keeping the recorded headers
	headers := headersFromCassetteValue(v2.Request["headers"])
	headers.Set(xInteractionIndexHeader, strconv.Itoa(id))
	v2.Request["headers"] = headers
	v2.Request["order"] = id
	// remove duration if empty
	if v, ok := v1.Response["duration"].(string); ok && len(v) == 0 {
		delete(v2.Response, "duration")
	}
	return v2
}

// headersFromCassetteValue converts a decoded "headers" value into a fresh http.Header.
// Cassettes decoded via YAML->JSON into map[string]interface{} represent headers as map[string]interface{}
// with []interface{} values, so checking only for map[string][]string would silently drop every recorded header.
// Header names are kept as recorded; non-string values and unsupported shapes are ignored.
func headersFromCassetteValue(raw interface{}) http.Header {
	headers := http.Header{}
	switch h := raw.(type) {
	case http.Header:
		for k, vs := range h {
			headers[k] = append([]string(nil), vs...)
		}
	case map[string][]string:
		for k, vs := range h {
			headers[k] = append([]string(nil), vs...)
		}
	case map[string]interface{}:
		for k, v := range h {
			switch vs := v.(type) {
			case []interface{}:
				for _, item := range vs {
					if s, ok := item.(string); ok {
						headers[k] = append(headers[k], s)
					}
				}
			case []string:
				headers[k] = append(headers[k], vs...)
			case string:
				headers[k] = append(headers[k], vs)
			}
		}
	}
	return headers
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ittest
import (
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/spyzhov/ajson"
"gopkg.in/dnaeon/go-vcr.v3/cassette"
"gopkg.in/dnaeon/go-vcr.v3/recorder"
"mime"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
)
const (
	// DefaultHost is the placeholder host name "webservice".
	// NOTE(review): presumably used when rewriting/sanitizing recorded hosts — confirm at its call sites.
	DefaultHost = "webservice"
)

// Names of the built-in recorder hooks. HttpRecorderHooks replaces an installed hook with the same name,
// and DisableHttpRecorderHooks removes hooks by name.
const (
	// HookNameIndexAware names InteractionIndexAwareHook.
	HookNameIndexAware = "index-aware"
	// HookNameSanitize presumably names SanitizingHook — confirm.
	HookNameSanitize = "sanitize"
	// HookNameFixedDuration names the fixed-duration hook that FixedHttpRecordDuration installs/removes.
	HookNameFixedDuration = "fixed-duration"
	// HookNameLocalhostRewrite names a hook rewriting localhost addresses (NOTE(review): inferred from the name).
	HookNameLocalhostRewrite = "localhost-rewrite"
)
/************************
Common
************************/
// NewRecorderHook creates a named RecorderHook that runs fn as a go-vcr hook of the given kind.
// The returned hook has no explicit order; see NewRecorderHookWithOrder.
func NewRecorderHook(name string, fn recorder.HookFunc, kind recorder.HookKind) RecorderHook {
	h := recorderHook{name: name}
	h.hook.Handler = fn
	h.hook.Kind = kind
	return h
}
// recorderHook is the default RecorderHook implementation: a name paired with a go-vcr recorder.Hook.
type recorderHook struct {
	name string
	hook recorder.Hook
}

// Name returns the name used to replace or remove this hook.
func (h recorderHook) Name() string { return h.name }

// Handler returns the go-vcr hook function.
func (h recorderHook) Handler() recorder.HookFunc { return h.hook.Handler }

// Kind returns the go-vcr hook kind, i.e. when the hook is invoked.
func (h recorderHook) Kind() recorder.HookKind { return h.hook.Kind }
// NewRecorderHookWithOrder is like NewRecorderHook, but the returned hook also has an Order() method returning
// order, so it is treated as ordered when hooks are sorted.
func NewRecorderHookWithOrder(name string, fn recorder.HookFunc, kind recorder.HookKind, order int) RecorderHook {
	base := recorderHook{
		name: name,
		hook: recorder.Hook{Handler: fn, Kind: kind},
	}
	return orderedRecorderHook{recorderHook: base, order: order}
}
// orderedRecorderHook is a recorderHook with an explicit order, exposed via Order().
type orderedRecorderHook struct {
	recorderHook
	order int
}

// Order returns the hook's configured order.
func (h orderedRecorderHook) Order() int {
	return h.order
}
/************************
Sanitizer
************************/
var (
	// HeaderSanitizers maps header names to the sanitizer applied to their values.
	// Headers selected for sanitizing but not listed here use DefaultValueSanitizer.
	HeaderSanitizers = map[string]ValueSanitizer{
		"Date":          SubstituteValueSanitizer("Fri, 19 Aug 2022 8:51:32 GMT"),
		"Authorization": RegExpValueSanitizer("^(?P<prefix>Basic |Bearer |Digest ).*|.*", "${prefix}******"),
	}
	// QuerySanitizers maps query/form keys to the sanitizer applied to their values.
	QuerySanitizers = map[string]ValueSanitizer{
		"access_token": DefaultValueSanitizer(),
		"nonce":        DefaultValueSanitizer(),
		"password":     DefaultValueSanitizer(),
		"secret":       DefaultValueSanitizer(),
		"token":        DefaultValueSanitizer(),
	}
	// BodySanitizers maps JSON field names to the sanitizer applied to their values.
	BodySanitizers = map[string]ValueSanitizer{
		"access_token": DefaultValueSanitizer(),
	}
)

// ValueSanitizer maps an original value to the value stored in the recording.
type ValueSanitizer func(any) any

// RegExpValueSanitizer returns a ValueSanitizer that rewrites string values with regexp.ReplaceAllString
// (regex, repl). Non-string values are returned unchanged. It panics if regex does not compile.
func RegExpValueSanitizer(regex, repl string) ValueSanitizer {
	re := regexp.MustCompile(regex)
	return func(v any) any {
		if s, ok := v.(string); ok {
			return re.ReplaceAllString(s, repl)
		}
		return v
	}
}

// SubstituteValueSanitizer returns a ValueSanitizer that ignores its input and always yields repl.
func SubstituteValueSanitizer(repl any) ValueSanitizer {
	return func(any) any { return repl }
}

// DefaultValueSanitizer returns a ValueSanitizer that replaces every value with "_hidden".
func DefaultValueSanitizer() ValueSanitizer {
	return SubstituteValueSanitizer("_hidden")
}
/************************
Hooks Functions
************************/
// InteractionIndexAwareHook returns a BeforeSaveHook that writes each interaction's ID into the recorded
// request's xInteractionIndexHeader header.
// httpvcr keeps the interaction ID but does not pass it to cassette.MatcherFunc, so it is stored in the
// recorded request where request matchers can read it back.
func InteractionIndexAwareHook() RecorderHook {
	return NewRecorderHook(HookNameIndexAware, func(i *cassette.Interaction) error {
		index := strconv.Itoa(i.ID)
		i.Request.Headers.Set(xInteractionIndexHeader, index)
		return nil
	}, recorder.BeforeSaveHook)
}
// SanitizingHook returns an HTTP VCR BeforeSaveHook (order 0) that sanitizes values in headers, queries and
// bodies (x-www-form-urlencoded or JSON) of an interaction before it is saved.
// Which headers/keys/JSON paths are sanitized comes from FuzzyRequestHeaders, FuzzyRequestQueries,
// FuzzyRequestJsonPaths, FuzzyResponseHeaders and FuzzyResponseJsonPaths; how each value is replaced comes from
// HeaderSanitizers, QuerySanitizers and BodySanitizers.
// Note: sanitized values cannot be matched exactly. If the sanitizer configuration changes, make sure to
// configure fuzzy matching accordingly.
//
// See NewRecordMatcher, FuzzyHeaders, FuzzyQueries, FuzzyForm and FuzzyJsonPaths
func SanitizingHook() RecorderHook {
	requestPaths := parseJsonPaths(FuzzyRequestJsonPaths.Values())
	responsePaths := parseJsonPaths(FuzzyResponseJsonPaths.Values())
	sanitize := func(i *cassette.Interaction) error {
		req := &i.Request
		req.Headers = sanitizeHeaders(req.Headers, FuzzyRequestHeaders)
		req.URL = sanitizeUrl(req.URL, FuzzyRequestQueries)
		if mt := mediaType(req.Headers); mt == "application/x-www-form-urlencoded" {
			req.Body = sanitizeRequestForm(req, FuzzyRequestQueries)
		} else if mt == "application/json" {
			req.Body = sanitizeJsonBody(req.Body, BodySanitizers, requestPaths)
		}

		i.Response.Headers = sanitizeHeaders(i.Response.Headers, FuzzyResponseHeaders)
		if mediaType(i.Response.Headers) == "application/json" {
			i.Response.Body = sanitizeJsonBody(i.Response.Body, BodySanitizers, responsePaths)
		}
		return nil
	}
	return NewRecorderHookWithOrder(HookNameSanitize, sanitize, recorder.BeforeSaveHook, 0)
}
// LocalhostRewriteHook returns a BeforeSaveHook that, when the request host starts with "localhost" or
// "127.0.0.1", replaces the first occurrence of that host in the request URL with DefaultHost and sets the
// request host to DefaultHost, so recordings don't depend on the (random) local address.
func LocalhostRewriteHook() RecorderHook {
	rewrite := func(i *cassette.Interaction) error {
		host := i.Request.Host
		isLocal := strings.HasPrefix(host, "localhost") || strings.HasPrefix(host, "127.0.0.1")
		if !isLocal {
			return nil
		}
		i.Request.URL = strings.Replace(i.Request.URL, host, DefaultHost, 1)
		i.Request.Host = DefaultHost
		return nil
	}
	return NewRecorderHook(HookNameLocalhostRewrite, rewrite, recorder.BeforeSaveHook)
}
// FixedDurationHook returns a BeforeSaveHook that overwrites the recorded response duration with the given
// constant, so recordings don't vary with actual latency.
func FixedDurationHook(duration time.Duration) RecorderHook {
	return NewRecorderHook(HookNameFixedDuration, func(i *cassette.Interaction) error {
		i.Response.Duration = duration
		return nil
	}, recorder.BeforeSaveHook)
}
/************************
helpers
************************/
// mediaType returns the lower-cased media type of the header's Content-Type with parameters stripped.
// Parse errors are ignored: whatever mime.ParseMediaType returns as the media type is used ("" when the
// header is absent).
func mediaType(header http.Header) (media string) {
	media, _, _ = mime.ParseMediaType(header.Get("Content-Type"))
	return
}
// sanitizeValues replaces, in place, every value whose key is in keys. It uses the sanitizer registered for that
// key in sanitizers, or DefaultValueSanitizer when none is registered. Values of other keys are left untouched.
// The same map is returned for convenience.
//
// Sanitizers are typed func(any) any and may return non-string results (e.g. SubstituteValueSanitizer with a
// number). Such results are formatted with fmt.Sprint instead of panicking on a string type assertion.
func sanitizeValues(values map[string][]string, sanitizers map[string]ValueSanitizer, keys utils.StringSet) map[string][]string {
	for k, vals := range values {
		if !keys.Has(k) {
			continue
		}
		sanitizer, ok := sanitizers[k]
		if !ok {
			sanitizer = DefaultValueSanitizer()
		}
		// vals shares its backing array with values[k], so assigning vals[i] updates the map in place
		for i, v := range vals {
			switch s := sanitizer(v).(type) {
			case string:
				vals[i] = s
			default:
				vals[i] = fmt.Sprint(s)
			}
		}
	}
	return values
}
// sanitizeHeaders sanitizes, in place, the values of the headers listed in headerKeys using HeaderSanitizers,
// and returns the same header map.
func sanitizeHeaders(headers http.Header, headerKeys utils.StringSet) http.Header {
	sanitized := sanitizeValues(headers, HeaderSanitizers, headerKeys)
	return sanitized
}
// sanitizeUrl sanitizes the query values of queryKeys in raw using QuerySanitizers, and returns the re-encoded
// URL (queries re-encoded by url.Values.Encode). An unparsable URL is returned unchanged.
func sanitizeUrl(raw string, queryKeys utils.StringSet) string {
	parsed, e := url.Parse(raw)
	if e == nil {
		sanitized := url.Values(sanitizeValues(parsed.Query(), QuerySanitizers, queryKeys))
		parsed.RawQuery = sanitized.Encode()
		return parsed.String()
	}
	return raw
}
// sanitizeRequestForm sanitizes req.Form in place using QuerySanitizers, and returns the form re-encoded with
// url.Values.Encode, to be used as the new request body.
func sanitizeRequestForm(req *cassette.Request, queryKeys utils.StringSet) string {
	form := sanitizeValues(req.Form, QuerySanitizers, queryKeys)
	req.Form = form
	return req.Form.Encode()
}
// sanitizeJsonBody replaces the values of string, numeric and boolean JSON nodes matching any of jsonPaths.
// The sanitizer is chosen by the node's key in sanitizers, falling back to DefaultValueSanitizer. Other node
// types are left as-is. The body is returned unchanged when jsonPaths is empty or body is not valid JSON.
// Paths that fail to apply are skipped, and errors from setting a node are ignored (best-effort).
func sanitizeJsonBody(body string, sanitizers map[string]ValueSanitizer, jsonPaths []parsedJsonPath) string {
	if len(jsonPaths) == 0 {
		return body
	}
	root, e := ajson.Unmarshal([]byte(body))
	if e != nil {
		return body
	}
	for _, p := range jsonPaths {
		matches, applyErr := ajson.ApplyJSONPath(root, p.Parsed)
		if applyErr != nil {
			continue
		}
		for _, node := range matches {
			sanitize, ok := sanitizers[node.Key()]
			if !ok {
				sanitize = DefaultValueSanitizer()
			}
			var original any
			switch node.Type() {
			case ajson.String:
				original = node.MustString()
			case ajson.Numeric:
				original = node.MustNumeric()
			case ajson.Bool:
				original = node.MustBool()
			default:
				continue
			}
			_ = node.Set(sanitize(original))
		}
	}
	return root.String()
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ittest
import (
"bytes"
"fmt"
"gopkg.in/dnaeon/go-vcr.v3/cassette"
"io"
"net/http"
"net/url"
"strconv"
)
// errInteractionIDMismatch is returned by the index-aware matcher for recorded interactions that are not the
// one expected next.
var (
	errInteractionIDMismatch = fmt.Errorf("HTTP interaction ID doesn't match")
)
// RecordMatcherOptions is a functional option that customizes a RecordMatcherOption.
type RecordMatcherOptions func(opt *RecordMatcherOption)

// RecordMatcherOption configures the matcher created by NewRecordMatcher.
type RecordMatcherOption struct {
	// Convenient Options

	// IgnoreHost skips comparison of the URL host.
	IgnoreHost bool
	// FuzzyHeaders are header names whose values are not compared (the header must still exist).
	FuzzyHeaders []string
	// FuzzyQueries are query keys whose values are not compared (the key must still exist).
	FuzzyQueries []string
	// FuzzyPostForm are x-www-form-urlencoded body keys whose values are not compared.
	FuzzyPostForm []string
	// FuzzyJsonPaths are JSONPaths whose values in JSON bodies are not compared.
	FuzzyJsonPaths []string

	// Advanced Options. When set, they replace the matchers otherwise built from the convenient options.
	// Note: directly changing these defaults requires knowledge about golang generics and function casting.

	URLMatcher    RecordURLMatcherFunc
	QueryMatcher  RecordQueryMatcherFunc
	HeaderMatcher RecordHeaderMatcherFunc
	// BodyMatchers are consulted before the built-in body matchers; the first one supporting the recorded
	// Content-Type decides.
	BodyMatchers []RecordBodyMatcher
}
// NewRecordMatcher creates a custom matcher that compares an outgoing request against a recorded one.
// With default options, the created matcher compares the following:
//   - Method, Host and Path are exact matches (Host comparison can be disabled with IgnoreHost)
//   - Queries are exact matches except for FuzzyRequestQueries
//   - Headers are exact matches except for FuzzyRequestHeaders
//   - Body is compared by the first body matcher supporting the recorded Content-Type
//     (JSON, x-www-form-urlencoded form, or literal comparison)
//
// The outgoing request's body is read for the comparison and then replaced with a re-readable copy, so the
// request can still be sent. Failures to parse the recorded URL or to read the body are returned wrapped (%w).
//
// Note: In case the request contains random/temporal data in queries/headers/form/JSON, use Fuzzy* options
func NewRecordMatcher(opts ...RecordMatcherOptions) GenericMatcherFunc[*http.Request, cassette.Request] {
	opt := resolveMatcherOption(opts)
	return func(out *http.Request, record cassette.Request) error {
		recordUrl, e := url.Parse(record.URL)
		if e != nil {
			return fmt.Errorf("invalid recorded URL: %w", e)
		}
		if out.Method != record.Method {
			return fmt.Errorf("http method mismatch")
		}
		if e := opt.URLMatcher(out.URL, recordUrl); e != nil {
			return e
		}
		if e := opt.QueryMatcher(out.URL.Query(), recordUrl.Query()); e != nil {
			return e
		}
		if e := opt.HeaderMatcher(out.Header, record.Headers); e != nil {
			return e
		}

		var reqBody []byte
		if out.Body != nil && out.Body != http.NoBody {
			data, e := io.ReadAll(out.Body)
			// the original body is consumed at this point; close it regardless of the read outcome.
			// A close error carries no information useful for matching, so it is deliberately ignored.
			_ = out.Body.Close()
			if e != nil {
				return fmt.Errorf("unable to read request's body: %w", e)
			}
			reqBody = data
			// put back a re-readable copy so the request can still be sent after matching
			out.Body = io.NopCloser(bytes.NewReader(data))
		}

		// find first supported body matcher and use it
		contentType := record.Headers.Get("Content-Type")
		for _, matcher := range opt.BodyMatchers {
			if matcher.Support(contentType) {
				return matcher.Matches(reqBody, []byte(record.Body))
			}
		}
		return nil
	}
}
// resolveMatcherOption builds the effective RecordMatcherOption.
// It starts from defaults: fuzzy headers, queries, POST form keys and JSON paths come from the global
// FuzzyRequest* sets, which are the same keys/paths SanitizingHook rewrites in recorded requests.
// It then applies opts in order and fills in any URL/query/header matcher not set explicitly.
// Finally it appends the built-in body matchers (JSON, form, literal) after any user-supplied BodyMatchers.
func resolveMatcherOption(opts []RecordMatcherOptions) *RecordMatcherOption {
	opt := RecordMatcherOption{
		IgnoreHost:    false,
		FuzzyHeaders:  FuzzyRequestHeaders.Values(),
		FuzzyQueries:  FuzzyRequestQueries.Values(),
		FuzzyPostForm: FuzzyRequestQueries.Values(),
		// SanitizingHook replaces values at FuzzyRequestJsonPaths in recorded request bodies. Those values can
		// never match exactly, so they are compared fuzzily by default, like headers and queries above.
		FuzzyJsonPaths: FuzzyRequestJsonPaths.Values(),
	}
	for _, fn := range opts {
		fn(&opt)
	}
	if opt.URLMatcher == nil {
		opt.URLMatcher = RecordURLMatcherFunc(NewRecordURLMatcher(opt.IgnoreHost))
	}
	if opt.QueryMatcher == nil {
		opt.QueryMatcher = RecordQueryMatcherFunc(NewRecordQueryMatcher(opt.FuzzyQueries...))
	}
	if opt.HeaderMatcher == nil {
		opt.HeaderMatcher = RecordHeaderMatcherFunc(NewRecordHeaderMatcher(opt.FuzzyHeaders...))
	}
	// built-in body matchers go last, so user-supplied BodyMatchers take precedence
	opt.BodyMatchers = append(opt.BodyMatchers,
		NewRecordJsonBodyMatcher(opt.FuzzyJsonPaths...),
		NewRecordFormBodyMatcher(opt.FuzzyPostForm...),
		NewRecordLiteralBodyMatcher(),
	)
	return &opt
}
// indexAwareMatcherWrapper is a special matcher wrapper that ensures requests are executed in the recorded order.
type indexAwareMatcherWrapper struct {
	// count is the number of distinct outgoing requests seen so far
	count int
}

// newIndexAwareMatcherWrapper returns a wrapper that has not seen any request yet.
func newIndexAwareMatcherWrapper() *indexAwareMatcherWrapper {
	return new(indexAwareMatcherWrapper)
}
// MatcherFunc wraps the given delegate with interaction-order enforcement.
//
// Note 1: because the current httpvcr lib doesn't expose the interaction ID, InteractionIndexAwareHook stores it
// in the recorded request's xInteractionIndexHeader header, and it is read back here. Records without a
// parsable index never match.
//
// Note 2: the delegate is only invoked for the record whose index is the expected one. Every other record is
// rejected with errInteractionIDMismatch without invoking the delegate.
//
// Note 3: the expected index advances once per outgoing request, the first time that request is evaluated.
// The request is then tagged with xInteractionSeenHeader, so evaluating it against further records does not
// advance the index again. The index advances whether or not any record ends up matching the request.
//
// Note 4: w.count is read and written without synchronization.
func (w *indexAwareMatcherWrapper) MatcherFunc(delegate RecordMatcherFunc) GenericMatcherFunc[*http.Request, cassette.Request] {
	return func(out *http.Request, record cassette.Request) error {
		recordId, e := strconv.Atoi(record.Headers.Get(xInteractionIndexHeader))
		if e != nil {
			// -1 never matches: count is at least 1 once the current request has been seen
			recordId = -1
		}
		seen := len(out.Header.Get(xInteractionSeenHeader)) != 0
		if !seen {
			// a new request: advance the expectation and mark the request as seen
			out.Header.Set(xInteractionSeenHeader, "true")
			w.count++
		}
		// do interaction index match first
		if w.count != recordId+1 {
			return errInteractionIDMismatch
		}
		// this is the expected record; let the delegate decide
		return delegate(out, record)
	}
}
/*********************
Matcher Options
*********************/
// IgnoreHost returns RecordMatcherOptions that skip the host comparison during record matching.
func IgnoreHost() RecordMatcherOptions {
	return func(o *RecordMatcherOption) { o.IgnoreHost = true }
}
// FuzzyHeaders returns RecordMatcherOptions that add the given header names to those whose values are not
// compared during record matching.
// Note: the header must still exist; only the value comparison is skipped.
func FuzzyHeaders(headers ...string) RecordMatcherOptions {
	return func(o *RecordMatcherOption) {
		o.FuzzyHeaders = append(o.FuzzyHeaders, headers...)
	}
}
// FuzzyQueries returns RecordMatcherOptions that ignore the values of the given keys during record matching.
// Note: the key must still exist; only the value comparison is skipped.
// Note: this is equivalent to FuzzyForm. The keys are treated as fuzzy both in URL queries and in
// x-www-form-urlencoded POST bodies.
func FuzzyQueries(queries ...string) RecordMatcherOptions {
	return FuzzyForm(queries...)
}
// FuzzyForm returns RecordMatcherOptions that ignore the values of the given keys, both in queries and in
// x-www-form-urlencoded POST bodies, during record matching.
// Note: the key must still exist; only the value comparison is skipped.
func FuzzyForm(formKeys ...string) RecordMatcherOptions {
	return func(o *RecordMatcherOption) {
		o.FuzzyQueries = append(o.FuzzyQueries, formKeys...)
		o.FuzzyPostForm = append(o.FuzzyPostForm, formKeys...)
	}
}
// FuzzyJsonPaths returns RecordMatcherOptions that ignore the values of JSON body fields matching the given
// JSONPaths during record matching (the fields must still exist).
// JSONPath Syntax: https://goessner.net/articles/JsonPath/
func FuzzyJsonPaths(jsonPaths ...string) RecordMatcherOptions {
	return func(o *RecordMatcherOption) {
		o.FuzzyJsonPaths = append(o.FuzzyJsonPaths, jsonPaths...)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ittest
import (
"fmt"
"github.com/spyzhov/ajson"
"mime"
"net/url"
"reflect"
"strings"
)
// RecordLiteralBodyMatcher is a body matcher backed by a plain comparison function that supports every
// content type.
type RecordLiteralBodyMatcher GenericMatcherFunc[[]byte, []byte]

// Support always returns true.
func (m RecordLiteralBodyMatcher) Support(_ string) bool { return true }

// Matches delegates to the underlying comparison function.
func (m RecordLiteralBodyMatcher) Matches(out []byte, record []byte) error { return m(out, record) }
// NewRecordLiteralBodyMatcher returns RecordBodyMatcher that requires the outgoing body to be byte-for-byte
// identical to the recorded one. A length difference is reported separately from a content difference.
func NewRecordLiteralBodyMatcher() RecordBodyMatcher {
	return RecordLiteralBodyMatcher(func(out []byte, record []byte) error {
		switch {
		case len(out) != len(record):
			return fmt.Errorf("body lengths mismatch")
		case string(out) != string(record):
			return fmt.Errorf("body content mismatch")
		default:
			return nil
		}
	})
}
// RecordFormBodyMatcher is a body matcher for application/x-www-form-urlencoded bodies.
type RecordFormBodyMatcher GenericMatcherFunc[[]byte, []byte]

// Support reports whether contentType parses to the application/x-www-form-urlencoded media type.
func (m RecordFormBodyMatcher) Support(contentType string) bool {
	if media, _, e := mime.ParseMediaType(contentType); e == nil {
		return media == "application/x-www-form-urlencoded"
	}
	return false
}

// Matches delegates to the underlying comparison function.
func (m RecordFormBodyMatcher) Matches(out []byte, record []byte) error {
	return m(out, record)
}
// NewRecordFormBodyMatcher returns RecordBodyMatcher that matches request bodies as application/x-www-form-urlencoded.
// The returned matcher only supports x-www-form-urlencoded content types; other bodies are left to the next
// body matcher. Values of keys in fuzzyKeys are not compared, but the outgoing body needs to contain every key
// present in the recorded body. Keys present only in the outgoing body are not checked.
func NewRecordFormBodyMatcher(fuzzyKeys ...string) RecordBodyMatcher {
	valuesMatcher := newValuesMatcher("form body", nil, fuzzyKeys...)
	// must be RecordFormBodyMatcher: a RecordLiteralBodyMatcher's Support() accepts every content type, which
	// would apply form comparison to arbitrary bodies and shadow the literal matcher
	return RecordFormBodyMatcher(func(out []byte, record []byte) error {
		return valuesMatcher(parseFormBody(out), parseFormBody(record))
	})
}
// RecordJsonBodyMatcher is a body matcher for application/json bodies.
type RecordJsonBodyMatcher GenericMatcherFunc[[]byte, []byte]

// Support reports whether contentType parses to the application/json media type.
func (m RecordJsonBodyMatcher) Support(contentType string) bool {
	if media, _, e := mime.ParseMediaType(contentType); e == nil {
		return media == "application/json"
	}
	return false
}

// Matches delegates to the underlying comparison function.
func (m RecordJsonBodyMatcher) Matches(out []byte, record []byte) error {
	return m(out, record)
}
// NewRecordJsonBodyMatcher returns a RecordBodyMatcher that matches the JSON bodies of recorded and outgoing
// requests.
// Nodes matching any of the optional fuzzyJsonPaths are removed from both documents, and the remainders must
// be deeply equal. In addition, every fuzzy path that matched in the recorded body must also match in the
// outgoing body (the field must exist; its value is not compared).
// Two empty bodies are considered a match: an empty body is not valid JSON, but a body-less request may still
// carry a JSON Content-Type header.
// It panics if any of fuzzyJsonPaths is not a valid JSONPath (see parseJsonPaths).
func NewRecordJsonBodyMatcher(fuzzyJsonPaths ...string) RecordBodyMatcher {
	parsedPaths := parseJsonPaths(fuzzyJsonPaths)
	return RecordJsonBodyMatcher(func(out []byte, record []byte) error {
		if len(out) == 0 && len(record) == 0 {
			return nil
		}
		rRoot, rMatched, e := parseJsonWithFilter(record, parsedPaths)
		if e != nil {
			return fmt.Errorf("invalid recorded JSON body: %w", e)
		}
		lRoot, lMatched, e := parseJsonWithFilter(out, parsedPaths)
		if e != nil {
			return fmt.Errorf("invalid outgoing JSON body: %w", e)
		}
		// first compare what remains after removing fuzzy nodes; it needs to be identical
		if !reflect.DeepEqual(lRoot, rRoot) {
			return fmt.Errorf("JSON body content mismatch")
		}
		// second, check that every fuzzy path matched in the record also exists in the outgoing body
	OUTER:
		for _, p := range rMatched {
			for _, lp := range lMatched {
				if p.Value == lp.Value {
					continue OUTER
				}
			}
			return fmt.Errorf("JSON body content mismatch: missing [%s]", p.Value)
		}
		return nil
	})
}
/**********************
helpers
**********************/
// parsedJsonPath pairs a JSONPath expression (Value) with its tokens as returned by ajson.ParseJSONPath (Parsed).
type parsedJsonPath struct {
	Value  string
	Parsed []string
}
// parseJsonPaths parses every JSONPath expression, keeping the original text next to the parsed tokens.
// The result is a non-nil slice in the same order as jsonPaths. It panics on the first invalid expression.
func parseJsonPaths(jsonPaths []string) []parsedJsonPath {
	result := make([]parsedJsonPath, len(jsonPaths))
	for i, raw := range jsonPaths {
		tokens, e := ajson.ParseJSONPath(raw)
		if e != nil {
			panic(e)
		}
		result[i] = parsedJsonPath{Value: raw, Parsed: tokens}
	}
	return result
}
// parseJsonWithFilter parses data as JSON and deletes every node matching any of jsonPaths. It returns the
// remaining document unpacked into Go values, together with the paths that matched at least one node.
// Paths that fail to apply are skipped, and node deletion errors are ignored.
func parseJsonWithFilter(data []byte, jsonPaths []parsedJsonPath) (interface{}, []parsedJsonPath, error) {
	root, e := ajson.Unmarshal(data)
	if e != nil {
		return nil, nil, e
	}
	matched := make([]parsedJsonPath, 0, len(jsonPaths))
	for _, p := range jsonPaths {
		nodes, applyErr := ajson.ApplyJSONPath(root, p.Parsed)
		if applyErr != nil || len(nodes) == 0 {
			continue
		}
		matched = append(matched, p)
		for _, n := range nodes {
			_ = n.Delete()
		}
	}
	unpacked, e := root.Unpack()
	return unpacked, matched, e
}
// parseFormBody parses an x-www-form-urlencoded body into url.Values.
// Both keys and values are query-unescaped ('+' becomes a space, %XX is decoded), so fuzzy form keys are
// compared against decoded names (e.g. "a[]" rather than "a%5B%5D"). Pairs without '=' and pairs whose key or
// value cannot be unescaped are skipped. Repeated keys keep all values in order.
func parseFormBody(data []byte) url.Values {
	parsed := url.Values{}
	for _, pair := range strings.Split(string(data), "&") {
		rawKey, rawValue, found := strings.Cut(pair, "=")
		if !found {
			continue
		}
		k, e := url.QueryUnescape(rawKey)
		if e != nil {
			continue
		}
		v, e := url.QueryUnescape(rawValue)
		if e != nil {
			continue
		}
		parsed.Add(k, v)
	}
	return parsed
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ittest
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils"
"net/http"
"net/url"
)
// NewRecordURLMatcher returns a matcher (convertible to RecordURLMatcherFunc) that compares the Path and Host of
// the outgoing and recorded URLs. url.URL.Host includes the port, if any, so the port is compared as part of
// the host. Host (and therefore port) comparison is skipped when ignoreHost is true.
// Note: the HTTP method is not part of the URL and is not compared here.
func NewRecordURLMatcher(ignoreHost bool) GenericMatcherFunc[*url.URL, *url.URL] {
	return func(out *url.URL, record *url.URL) error {
		if out.Path != record.Path {
			return fmt.Errorf("http path mismatch")
		}
		if !ignoreHost && out.Host != record.Host {
			return fmt.Errorf("http host mismatch")
		}
		return nil
	}
}
// NewRecordQueryMatcher returns a matcher (convertible to RecordQueryMatcherFunc) that requires every recorded
// query key to exist in the outgoing query. The values must match exactly (same count and order) unless the
// key is one of fuzzyKeys.
func NewRecordQueryMatcher(fuzzyKeys ...string) GenericMatcherFunc[url.Values, url.Values] {
	const subject = "query"
	return newValuesMatcher(subject, nil, fuzzyKeys...)
}
// NewRecordHeaderMatcher returns a matcher (convertible to RecordHeaderMatcherFunc) that compares the keys and
// values of recorded and outgoing headers. Headers in IgnoredRequestHeaders are skipped entirely. Values of
// headers in fuzzyKeys are not compared, but the header must still be present.
// Each fuzzy name is registered both as given and in canonical form (http.CanonicalHeaderKey), because
// http.Header normally stores canonical keys. For example, "x-request-id" also covers "X-Request-Id".
func NewRecordHeaderMatcher(fuzzyKeys ...string) GenericMatcherFunc[http.Header, http.Header] {
	keys := make([]string, 0, 2*len(fuzzyKeys))
	for _, k := range fuzzyKeys {
		keys = append(keys, k, http.CanonicalHeaderKey(k))
	}
	delegate := newValuesMatcher("header", IgnoredRequestHeaders, keys...)
	return func(out http.Header, record http.Header) error {
		return delegate(url.Values(out), url.Values(record))
	}
}
/**********************
helpers
**********************/
// newValuesMatcher returns a matcher over url.Values. For every recorded key not in ignoredKeys, the key must
// exist in the outgoing values. Unless the key is one of fuzzyKeys, the outgoing values must also have the same
// length and equal elements in the same order. name is used in error messages. Keys that exist only in the
// outgoing values are not checked.
func newValuesMatcher(name string, ignoredKeys utils.StringSet, fuzzyKeys ...string) GenericMatcherFunc[url.Values, url.Values] {
	fuzzy := utils.NewStringSet(fuzzyKeys...)
	return func(out url.Values, record url.Values) error {
		for key, recorded := range record {
			if ignoredKeys != nil && ignoredKeys.Has(key) {
				continue
			}
			actual, present := out[key]
			if fuzzy.Has(key) {
				if !present {
					return fmt.Errorf("http %s [%s] missing", name, key)
				}
				continue
			}
			if !present || len(actual) != len(recorded) {
				return fmt.Errorf("http %s [%s] missing", name, key)
			}
			for i := range actual {
				if actual[i] != recorded[i] {
					return fmt.Errorf("http %s [%s] mismatch", name, key)
				}
			}
		}
		return nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ittest
import (
"github.com/cisco-open/go-lanai/pkg/integrate/httpclient"
secit "github.com/cisco-open/go-lanai/pkg/integrate/security"
"github.com/cisco-open/go-lanai/pkg/integrate/security/scope"
"github.com/cisco-open/go-lanai/pkg/integrate/security/seclient"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"go.uber.org/fx"
"gopkg.in/dnaeon/go-vcr.v3/cassette"
"gopkg.in/dnaeon/go-vcr.v3/recorder"
"net/http"
"strings"
"time"
)
// WithRecordedScopes returns test.Options that add the scope and seclient modules, and provide a
// RemoteTokenStoreReader (as oauth2.TokenStoreReader) plus an HTTP VCR hook that extends token validity in
// replayed token responses.
func WithRecordedScopes() test.Options {
	providers := []fx.Option{
		fx.Provide(provideScopeDI),
		fx.Provide(provideScopeVCROptions),
	}
	withModules := apptest.WithModules(scope.Module, seclient.Module)
	withProviders := apptest.WithFxOptions(providers...)
	return func(opt *test.T) {
		withModules(opt)
		withProviders(opt)
	}
}
/*************************
Providers
*************************/
// scopeDI holds the dependencies used to build the RemoteTokenStoreReader. Recorder is optional.
type scopeDI struct {
	fx.In
	ItProperties secit.SecurityIntegrationProperties
	HttpClient   httpclient.Client
	Recorder     *recorder.Recorder `optional:"true"`
}

// scopeDIOut exposes the token store reader to the fx container.
type scopeDIOut struct {
	fx.Out
	TokenReader oauth2.TokenStoreReader
}
// provideScopeDI provides a RemoteTokenStoreReader configured from the security integration properties, with
// SkipRemoteCheck enabled. When an HTTP VCR recorder is available, the reader's HTTP client is configured to
// use the recorder's default HTTP client.
func provideScopeDI(di scopeDI) scopeDIOut {
	props := di.ItProperties
	reader := NewRemoteTokenStoreReader(func(opt *RemoteTokenStoreOption) {
		opt.SkipRemoteCheck = true
		opt.HttpClient = di.HttpClient
		opt.BaseUrl, opt.ServiceName = props.Endpoints.BaseUrl, props.Endpoints.ServiceName
		opt.Scheme, opt.ContextPath = props.Endpoints.Scheme, props.Endpoints.ContextPath
		opt.ClientId, opt.ClientSecret = props.Client.ClientId, props.Client.ClientSecret
		if di.Recorder == nil {
			return
		}
		// route token requests through the recorder's HTTP client
		opt.HttpClientConfig = &httpclient.ClientConfig{
			HTTPClient: di.Recorder.GetDefaultClient(),
		}
	})
	return scopeDIOut{TokenReader: reader}
}
// scopeVCROptionsOut contributes HTTPVCROptions to the "http-vcr" fx value group.
type scopeVCROptionsOut struct {
	fx.Out
	VCROptions HTTPVCROptions `group:"http-vcr"`
}

// provideScopeVCROptions registers extendedTokenValidityHook as an additional HTTP recorder hook.
func provideScopeVCROptions() scopeVCROptionsOut {
	hooks := HttpRecorderHooks(extendedTokenValidityHook())
	return scopeVCROptionsOut{VCROptions: hooks}
}
/*************************
Additional Hooks
*************************/
// extendedTokenValidityHook returns a BeforeResponseReplayHook that pushes token validity far into the future
// (100*365 days from the time the hook is created).
// During scope switching, the token's expiry time decides whether a token needs to be refreshed. Without this
// hook, HTTP interactions would differ between recording time and replay time (after the token expired).
// It applies to 200 responses of requests whose URL contains "/v2/token" or "/v2/check_token", and rewrites the
// JSON fields "expiry" (RFC3339), "expires_in" (seconds) and "exp" (unix time).
func extendedTokenValidityHook() RecorderHook {
	const validity = 100 * 24 * 365 * time.Hour
	expiry := time.Now().Add(validity)
	sanitizers := map[string]ValueSanitizer{
		"expiry":     SubstituteValueSanitizer(expiry.Format(time.RFC3339)),
		"expires_in": SubstituteValueSanitizer(validity.Seconds()),
		"exp":        SubstituteValueSanitizer(expiry.Unix()),
	}
	paths := parseJsonPaths([]string{"$.expiry", "$.expires_in", "$.exp"})
	isTokenEndpoint := func(u string) bool {
		return strings.Contains(u, "/v2/token") || strings.Contains(u, "/v2/check_token")
	}
	fn := func(i *cassette.Interaction) error {
		if i.Response.Code == http.StatusOK && isTokenEndpoint(i.Request.URL) {
			i.Response.Body = sanitizeJsonBody(i.Response.Body, sanitizers, paths)
		}
		return nil
	}
	return NewRecorderHook("extend-token-validity", fn, recorder.BeforeResponseReplayHook)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package ittest
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/integrate/httpclient"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/misc"
"github.com/cisco-open/go-lanai/pkg/security/oauth2/jwt"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/test/sectest"
"net/http"
"net/url"
)
// CheckTokenPath is the auth service endpoint that RemoteTokenStoreReader posts tokens to for validation.
//
//nolint:gosec // G101: Potential hardcoded credentials, false positive
const CheckTokenPath = "/v2/check_token"
// RemoteTokenStoreReader implements oauth2.TokenStoreReader by calling the auth service's CheckTokenPath
// endpoint to load authentication.
// Note: this implementation is not a mock. With proper refactoring, it could potentially be used in production.
type RemoteTokenStoreReader struct {
	// JwtDecoder is optional. When provided, the token signature is checked locally before the token is sent
	// to the remote auth service.
	JwtDecoder jwt.JwtDecoder
	// SkipRemoteCheck, if true, skips the remote check when the token could be decoded locally and context
	// details are not required.
	SkipRemoteCheck bool
	// HttpClient is the httpclient.Client used for the remote token check.
	HttpClient httpclient.Client
	// ClientId and ClientSecret are sent as HTTP basic auth credentials to the check-token endpoint.
	ClientId     string
	ClientSecret string
}
// RemoteTokenStoreOptions is a functional option for NewRemoteTokenStoreReader.
type RemoteTokenStoreOptions func(opt *RemoteTokenStoreOption)

// RemoteTokenStoreOption configures NewRemoteTokenStoreReader.
type RemoteTokenStoreOption struct {
	JwtDecoder       jwt.JwtDecoder
	HttpClient       httpclient.Client        // base client the reader's client is derived from; required
	HttpClientConfig *httpclient.ClientConfig // optional config applied to the derived client
	BaseUrl          string                   // when set, the client targets this URL instead of ServiceName
	ServiceName      string                   // auth service's name for LB
	Scheme           string                   // scheme used together with ServiceName
	ContextPath      string                   // context path used together with ServiceName
	ClientId         string
	ClientSecret     string
	SkipRemoteCheck  bool
}
// NewRemoteTokenStoreReader creates an oauth2.TokenStoreReader that validates access tokens through the auth
// service's CheckTokenPath endpoint.
// The HTTP client is derived from opt.HttpClient (required): it targets BaseUrl when set, otherwise the service
// ServiceName with Scheme and ContextPath. HttpClientConfig, when non-nil, is applied to the derived client
// exactly once. It panics if the client cannot be derived.
func NewRemoteTokenStoreReader(opts ...RemoteTokenStoreOptions) oauth2.TokenStoreReader {
	var opt RemoteTokenStoreOption
	for _, fn := range opts {
		fn(&opt)
	}
	var client httpclient.Client
	var e error
	if opt.BaseUrl != "" {
		client, e = opt.HttpClient.WithBaseUrl(opt.BaseUrl)
	} else {
		client, e = opt.HttpClient.WithService(opt.ServiceName, func(sdOpt *httpclient.SDOption) {
			sdOpt.Scheme = opt.Scheme
			sdOpt.ContextPath = opt.ContextPath
		})
	}
	if e != nil {
		panic(e)
	}
	// apply the optional config only when given; never call WithConfig with a nil config
	if opt.HttpClientConfig != nil {
		client = client.WithConfig(opt.HttpClientConfig)
	}
	return &RemoteTokenStoreReader{
		JwtDecoder:      opt.JwtDecoder,
		SkipRemoteCheck: opt.SkipRemoteCheck,
		HttpClient:      client,
		ClientId:        opt.ClientId,
		ClientSecret:    opt.ClientSecret,
	}
}
// ReadAuthentication implements oauth2.TokenStoreReader. Only access tokens are supported; any other hint yields
// an unsupported-token-type error.
func (r *RemoteTokenStoreReader) ReadAuthentication(ctx context.Context, tokenValue string, hint oauth2.TokenHint) (oauth2.Authentication, error) {
	if hint != oauth2.TokenHintAccessToken {
		return nil, oauth2.NewUnsupportedTokenTypeError(fmt.Sprintf("token type [%s] is not supported", hint.String()))
	}
	return r.readAuthenticationFromAccessToken(ctx, tokenValue)
}
// ReadAccessToken implements oauth2.TokenStoreReader. It validates the token without requesting context
// details, and wraps any failure in an invalid access token error.
func (r *RemoteTokenStoreReader) ReadAccessToken(ctx context.Context, value string) (oauth2.AccessToken, error) {
	token, e := r.readAccessToken(ctx, value, nil)
	switch {
	case e != nil:
		return nil, oauth2.NewInvalidAccessTokenError("token is invalid", e)
	default:
		return token, nil
	}
}
// ReadRefreshToken implements oauth2.TokenStoreReader. It relies on parseRefreshToken, which currently does not
// support remote refresh token validation, so in practice it returns an invalid grant error.
//
//nolint:staticcheck // this feature is not fully implemented yet
func (r *RemoteTokenStoreReader) ReadRefreshToken(ctx context.Context, value string) (oauth2.RefreshToken, error) {
	token, e := r.parseRefreshToken(ctx, value)
	if e != nil {
		return nil, oauth2.NewInvalidGrantError("refresh token is invalid", e)
	}
	if token.WillExpire() && token.Expired() {
		return nil, oauth2.NewInvalidGrantError("refresh token is expired")
	}
	return token, nil
}
// readAccessToken validates the access token value and converts it into an *oauth2.DefaultAccessToken.
//
// If JwtDecoder is set, the signature is checked locally and the basic claims are decoded first.
// The remote check against CheckTokenPath is skipped only when all of the following hold: SkipRemoteCheck is
// true, detailedClaims is nil, and the locally decoded claims carry a token ID.
// Details are required when detailedClaims is non-nil or the token ID is unknown. In that case the remote check
// requests full details and fills detailedClaims (allocated when nil), and the token is built from them.
// Otherwise the remote check only confirms the token is active, and the token is built from the locally decoded
// claims.
func (r *RemoteTokenStoreReader) readAccessToken(ctx context.Context, value string, detailedClaims *misc.CheckTokenClaims) (*oauth2.DefaultAccessToken, error) {
	var basicClaims oauth2.BasicClaims
	// pre-check signature if possible
	if r.JwtDecoder != nil {
		if e := r.JwtDecoder.DecodeWithClaims(ctx, value, &basicClaims); e != nil {
			return nil, e
		}
	}
	requireDetails := detailedClaims != nil || len(basicClaims.Id) == 0
	// Note, we only skip revocation check when we have token claims, details is not required and SkipRemoteCheck is true
	if r.SkipRemoteCheck && !requireDetails {
		return r.createAccessToken(&basicClaims, value), nil
	}
	if requireDetails && detailedClaims == nil {
		detailedClaims = &misc.CheckTokenClaims{}
	}
	// perform remote check; a nil detailedClaims here means only the token's validity is checked (no details)
	if e := r.remoteAccessTokenCheck(ctx, value, detailedClaims); e != nil {
		return nil, e
	}
	if detailedClaims == nil {
		// details were not requested: build from the locally decoded claims (dereferencing detailedClaims here
		// would panic)
		return r.createAccessToken(&basicClaims, value), nil
	}
	return r.createAccessToken(&detailedClaims.BasicClaims, value), nil
}
// parseRefreshToken is a placeholder: remote refresh token validation is not supported, so it always returns
// an error.
//
//nolint:staticcheck // this feature is not fully implemented yet
func (r *RemoteTokenStoreReader) parseRefreshToken(_ context.Context, _ string) (*oauth2.DefaultRefreshToken, error) {
	const msg = "remote refresh token validation is not supported"
	return nil, fmt.Errorf(msg)
}
// readAuthenticationFromAccessToken validates the access token with full details, then reconstructs an
// oauth2.Authentication from the returned claims. The reconstruction covers mocked security details, the
// OAuth2 request, and the user authentication when the claims have a subject.
// Validation errors are returned as-is.
func (r *RemoteTokenStoreReader) readAuthenticationFromAccessToken(ctx context.Context, tokenValue string) (oauth2.Authentication, error) {
	// validate token and load detailed claims
	var claims misc.CheckTokenClaims
	token, e := r.readAccessToken(ctx, tokenValue, &claims)
	if e != nil {
		return nil, e
	}
	// load context details; building them from claims cannot fail, so there is no error to check here
	details := r.createSecurityDetails(&claims)
	// reconstruct request
	request := r.createOAuth2Request(&claims, details)
	// reconstruct user auth if available
	var userAuth security.Authentication
	if claims.Subject != "" {
		userAuth = r.createUserAuthentication(&claims, details)
	}
	return oauth2.NewAuthentication(func(opt *oauth2.AuthOption) {
		opt.Request = request
		opt.UserAuth = userAuth
		opt.Token = token
		opt.Details = details
	}), nil
}
/*****************
Helpers
*****************/
// remoteAccessTokenCheck POSTs the token as a form to CheckTokenPath, authenticating with the client's basic
// auth credentials, and decodes the JSON response into dest.
// When dest is nil, no details are requested (no_details=true) and the response is decoded into a throw-away
// value. It returns an error if the call fails or the response does not report the token as active.
func (r *RemoteTokenStoreReader) remoteAccessTokenCheck(ctx context.Context, value string, dest *misc.CheckTokenClaims) error {
	noDetails := dest == nil
	claims := dest
	if noDetails {
		claims = &misc.CheckTokenClaims{}
	}
	form := url.Values{
		"token":           {value},
		"token_type_hint": {"access_token"},
		"no_details":      {fmt.Sprintf("%v", noDetails)},
	}
	req := httpclient.NewRequest(CheckTokenPath, http.MethodPost,
		httpclient.WithUrlEncodedBody(form),
		httpclient.WithBasicAuth(r.ClientId, r.ClientSecret),
	)
	if _, e := r.HttpClient.Execute(ctx, req, httpclient.JsonBody(claims)); e != nil {
		return e
	}
	if active := claims.Active; active == nil || !*active {
		return fmt.Errorf("invalid token")
	}
	return nil
}
// createAccessToken builds a DefaultAccessToken for value, taking the expiry time, issue time, a copy of the
// scopes, and the claims themselves from claims.
func (r *RemoteTokenStoreReader) createAccessToken(claims *oauth2.BasicClaims, value string) *oauth2.DefaultAccessToken {
	scopes := claims.Scopes.Copy()
	token := oauth2.NewDefaultAccessToken(value)
	token.SetExpireTime(claims.ExpiresAt)
	token.SetIssueTime(claims.IssuedAt)
	token.SetScopes(scopes)
	token.SetClaims(claims)
	return token
}
// createSecurityDetails converts check-token claims into mocked security.ContextDetails.
// Every field is copied straight from claims except KVs, which starts as an empty map.
func (r *RemoteTokenStoreReader) createSecurityDetails(claims *misc.CheckTokenClaims) security.ContextDetails {
	return sectest.NewMockedSecurityDetails(func(d *sectest.SecurityDetailsMock) {
		*d = sectest.SecurityDetailsMock{
			// user
			Username:      claims.Username,
			UserId:        claims.UserId,
			OrigUsername:  claims.OrigUsername,
			UserFirstName: claims.FirstName,
			UserLastName:  claims.LastName,
			Permissions:   claims.Permissions,
			// tenancy
			TenantId:         claims.TenantId,
			TenantExternalId: claims.TenantExternalId,
			Tenants:          claims.AssignedTenants,
			// provider
			ProviderId:               claims.ProviderId,
			ProviderName:             claims.ProviderName,
			ProviderDisplayName:      claims.ProviderDisplayName,
			ProviderDescription:      claims.ProviderDescription,
			ProviderEmail:            claims.ProviderEmail,
			ProviderNotificationType: claims.ProviderNotificationType,
			// token timing
			Iss: claims.IssuedAt,
			Exp: claims.ExpiresAt,
			// extra key-values
			KVs: map[string]interface{}{},
		}
	})
}
// createOAuth2Request reconstructs an approved OAuth2 request from check-token claims.
//
// The client ID comes from claims.ClientId. When that is empty, the first value of the audience
// set is used instead.
// Request parameters are taken from the details' oauth2.DetailsKeyRequestParams entry, keeping
// only string values.
// Extensions start from claims.Values() and are overlaid with the
// oauth2.DetailsKeyRequestExt entry.
// Grant type, redirect URI and response types are not reconstructed.
func (r *RemoteTokenStoreReader) createOAuth2Request(claims *misc.CheckTokenClaims, details security.ContextDetails) oauth2.OAuth2Request {
	clientId := claims.ClientId
	if len(clientId) == 0 && len(claims.Audience) > 0 {
		clientId = utils.StringSet(claims.Audience).Values()[0]
	}

	params := map[string]string{}
	rawParams, _ := details.Value(oauth2.DetailsKeyRequestParams)
	if asMap, isMap := rawParams.(map[string]interface{}); isMap {
		for key, val := range asMap {
			if str, isStr := val.(string); isStr {
				params[key] = str
			}
		}
	}

	ext := claims.Values()
	rawExt, _ := details.Value(oauth2.DetailsKeyRequestExt)
	if extra, isMap := rawExt.(map[string]interface{}); isMap {
		for key, val := range extra {
			ext[key] = val
		}
	}

	return oauth2.NewOAuth2Request(func(opt *oauth2.RequestDetails) {
		opt.Parameters = params
		opt.ClientId = clientId
		opt.Scopes = claims.Scopes
		opt.Approved = true
		opt.Extensions = ext
		// GrantType, RedirectUri and ResponseTypes are intentionally left unset
	})
}
// createUserAuthentication builds an authenticated user authentication.
//
// Every permission key from details becomes a true-valued entry. The principal and the details
// map come from claims when claims is non-nil. Otherwise the principal is left unset and the
// details map is empty.
//
// The original read claims.Subject before its own nil check, so a nil claims would panic
// instead of reaching the guarded branch. The nil check now covers every claims access.
func (r *RemoteTokenStoreReader) createUserAuthentication(claims *misc.CheckTokenClaims, details security.ContextDetails) security.Authentication {
	permissions := map[string]interface{}{}
	for k := range details.Permissions() {
		permissions[k] = true
	}
	return oauth2.NewUserAuthentication(func(opt *oauth2.UserAuthOption) {
		opt.Permissions = permissions
		opt.State = security.StateAuthenticated
		opt.Details = map[string]interface{}{}
		if claims != nil {
			opt.Principal = claims.Subject
			opt.Details = claims.Values()
		}
	})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafkatest
import (
"context"
"github.com/cisco-open/go-lanai/pkg/kafka"
"github.com/cisco-open/go-lanai/pkg/log"
)
// MessageRecord is one message captured by a mocked producer: the topic it was sent to and the
// payload exactly as passed to Send.
type MessageRecord struct {
Topic string
Payload interface{}
}
// MessageRecorder is the interface for retrieving messages produced by MockedProducer.
// Reset discards all recorded messages, Records returns the messages recorded for one topic,
// and AllRecords returns every recorded message.
type MessageRecorder interface {
Reset()
Records(topic string) []*MessageRecord
AllRecords() []*MessageRecord
}
// messageRecorder extends MessageRecorder with Record, which the mocked producer calls to
// capture each sent message.
type messageRecorder interface {
MessageRecorder
Record(msg *MessageRecord)
}
// MessageMocker is the interface for mocking incoming messages.
// Mock delivers msg to the subscriber of topic; MockWithGroup delivers it to the consumer
// registered for topic and group.
type MessageMocker interface {
Mock(ctx context.Context, topic string, msg *kafka.Message) error
MockWithGroup(ctx context.Context, topic, group string, msg *kafka.Message) error
}
// msgLogger is a kafka.MessageLogger that writes sent/received messages to a contextual logger
// at a fixed level.
type msgLogger struct {
// destination logger
logger log.ContextualLogger
// level every message is logged at
level log.LoggingLevel
}
// WithLevel returns a copy of the logger that logs at level; the receiver is not modified.
func (l msgLogger) WithLevel(level log.LoggingLevel) kafka.MessageLogger {
	// l is already a copy (value receiver), so changing its level does not affect the caller's logger
	l.level = level
	return l
}

// LogSentMessage logs msg, prefixed with "Sent: ", using the context-bound logger at the configured level.
func (l msgLogger) LogSentMessage(ctx context.Context, msg interface{}) {
	scoped := l.logger.WithContext(ctx).WithLevel(l.level)
	scoped.Printf(`Sent: %v`, msg)
}

// LogReceivedMessage logs msg, prefixed with "Received: ", using the context-bound logger at the configured level.
func (l msgLogger) LogReceivedMessage(ctx context.Context, msg interface{}) {
	scoped := l.logger.WithContext(ctx).WithLevel(l.level)
	scoped.Printf(`Received: %v`, msg)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafkatest
import (
"context"
"github.com/cisco-open/go-lanai/pkg/kafka"
"github.com/cisco-open/go-lanai/pkg/utils"
"go.uber.org/fx"
"sync"
)
// mockedBinderOut is the fx result of provideMockedBinder: a single MockedBinder instance exposed
// under each of the roles tests may inject.
type mockedBinderOut struct {
fx.Out
Binder kafka.Binder
Mock *MockedBinder
Recorder MessageRecorder
Mocker MessageMocker
}
// provideMockedBinder creates an empty MockedBinder and provides it as kafka.Binder,
// *MockedBinder, MessageRecorder and MessageMocker.
func provideMockedBinder() mockedBinderOut {
	binder := &MockedBinder{
		producers:   map[string]*MockedProducer{},
		subscribers: map[string]*MockedSubscriber{},
		consumers:   map[string]map[string]*MockedConsumer{},
	}
	// the same instance backs every role, so tests observe one shared state
	return mockedBinderOut{
		Binder:   binder,
		Mock:     binder,
		Recorder: binder,
		Mocker:   binder,
	}
}
// MockedBinder implements kafka.Binder and messageRecorder (and therefore MessageRecorder), as well as MessageMocker.
// Producers, subscribers and consumers are created lazily and cached per topic (and per group for consumers).
type MockedBinder struct {
// guards every field below
mtx sync.Mutex
// topic -> producer
producers map[string]*MockedProducer
// topic -> subscriber
subscribers map[string]*MockedSubscriber
// topic -> group -> consumer
consumers map[string]map[string]*MockedConsumer
// messages captured from producers, in the order they were recorded
recordings []*MessageRecord
}
// Produce returns the cached MockedProducer for topic, creating one that records into b on first
// use. Producer options are ignored and the error is always nil.
func (b *MockedBinder) Produce(topic string, _ ...kafka.ProducerOptions) (kafka.Producer, error) {
	b.mtx.Lock()
	defer b.mtx.Unlock()
	if existing, found := b.producers[topic]; found {
		return existing, nil
	}
	created := NewMockedProducer(topic, b)
	b.producers[topic] = created
	return created, nil
}

// Subscribe returns the cached MockedSubscriber for topic, creating it on first use.
// Consumer options are ignored and the error is always nil.
func (b *MockedBinder) Subscribe(topic string, _ ...kafka.ConsumerOptions) (kafka.Subscriber, error) {
	b.mtx.Lock()
	defer b.mtx.Unlock()
	if existing, found := b.subscribers[topic]; found {
		return existing, nil
	}
	created := NewMockedSubscriber(topic)
	b.subscribers[topic] = created
	return created, nil
}

// Consume returns the cached MockedConsumer for (topic, group), creating the per-topic group map
// and the consumer on first use. Consumer options are ignored and the error is always nil.
func (b *MockedBinder) Consume(topic string, group string, _ ...kafka.ConsumerOptions) (kafka.GroupConsumer, error) {
	b.mtx.Lock()
	defer b.mtx.Unlock()
	byGroup, found := b.consumers[topic]
	if !found {
		byGroup = make(map[string]*MockedConsumer)
		b.consumers[topic] = byGroup
	}
	if existing, exists := byGroup[group]; exists {
		return existing, nil
	}
	created := NewMockedConsumer(topic, group)
	byGroup[group] = created
	return created, nil
}
// ListTopics returns the de-duplicated set of topics that have a producer, subscriber or
// consumer registered, in the order produced by the string set's Values().
func (b *MockedBinder) ListTopics() []string {
	b.mtx.Lock()
	defer b.mtx.Unlock()
	names := utils.NewStringSet()
	for topic := range b.producers {
		names.Add(topic)
	}
	for topic := range b.subscribers {
		names.Add(topic)
	}
	for topic := range b.consumers {
		names.Add(topic)
	}
	return names.Values()
}
// Reset discards every recorded message.
func (b *MockedBinder) Reset() {
	b.mtx.Lock()
	b.recordings = nil
	b.mtx.Unlock()
}

// Records returns, in recording order, the messages recorded for topic.
// The result is a new (possibly empty, never nil) slice.
func (b *MockedBinder) Records(topic string) (ret []*MessageRecord) {
	b.mtx.Lock()
	defer b.mtx.Unlock()
	ret = make([]*MessageRecord, 0, len(b.recordings))
	for i := range b.recordings {
		if rec := b.recordings[i]; rec.Topic == topic {
			ret = append(ret, rec)
		}
	}
	return ret
}

// AllRecords returns a copy of every recorded message, in recording order.
func (b *MockedBinder) AllRecords() (ret []*MessageRecord) {
	b.mtx.Lock()
	defer b.mtx.Unlock()
	ret = append(make([]*MessageRecord, 0, len(b.recordings)), b.recordings...)
	return ret
}

// Record appends record to the recorded messages; MockedProducer.Send calls it.
func (b *MockedBinder) Record(record *MessageRecord) {
	b.mtx.Lock()
	b.recordings = append(b.recordings, record)
	b.mtx.Unlock()
}
// Mock delivers msg to the subscriber registered for topic, as if it arrived from Kafka.
// It returns nil without doing anything when no subscriber exists for topic; otherwise it returns
// the dispatcher's result.
//
// The binder lock is held only while looking up the dispatcher, not during dispatch.
// A handler that calls back into the binder (e.g. producing a reply, which records
// through MockedBinder.Record, or calling Produce/Subscribe) would otherwise deadlock,
// because sync.Mutex is not reentrant.
func (b *MockedBinder) Mock(ctx context.Context, topic string, msg *kafka.Message) error {
	msgCtx := b.mockMessageContext(ctx, topic, msg)
	b.mtx.Lock()
	dispatcher, ok := b.subscribers[topic]
	b.mtx.Unlock()
	if !ok {
		return nil
	}
	return dispatcher.Dispatch(msgCtx)
}

// MockWithGroup delivers msg to the consumer registered for topic and group, as if it arrived
// from Kafka. It returns nil without doing anything when no such consumer exists; otherwise it
// returns the dispatcher's result.
//
// As in Mock, the binder lock is released before dispatching so handlers may call back into the binder.
func (b *MockedBinder) MockWithGroup(ctx context.Context, topic, group string, msg *kafka.Message) error {
	msgCtx := b.mockMessageContext(ctx, topic, msg)
	b.mtx.Lock()
	var dispatcher *MockedConsumer
	consumers, ok := b.consumers[topic]
	if ok {
		dispatcher, ok = consumers[group]
	}
	b.mtx.Unlock()
	if !ok {
		return nil
	}
	return dispatcher.Dispatch(msgCtx)
}
// mockMessageContext wraps msg in a kafka.MessageContext whose source is this binder.
// Message holds a copy of *msg and RawMessage the original pointer. msg must be non-nil.
func (b *MockedBinder) mockMessageContext(ctx context.Context, topic string, msg *kafka.Message) *kafka.MessageContext {
	mc := kafka.MessageContext{
		Context:    ctx,
		Source:     b,
		Topic:      topic,
		RawMessage: msg,
		Message:    *msg,
	}
	return &mc
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafkatest
import (
"github.com/cisco-open/go-lanai/pkg/kafka"
)
// MockedConsumer is a kafka.GroupConsumer stand-in whose embedded kafka.Dispatcher receives the
// messages injected via MockedBinder.MockWithGroup.
type MockedConsumer struct {
kafka.Dispatcher
// topic
T string
// consumer group
G string
}
// NewMockedConsumer creates a consumer for topic/group whose dispatcher logs through the
// package-level messageLogger.
func NewMockedConsumer(topic, group string) *MockedConsumer {
	c := &MockedConsumer{T: topic, G: group}
	c.Dispatcher = kafka.Dispatcher{Logger: messageLogger}
	return c
}

// Topic returns the consumer's topic.
func (c *MockedConsumer) Topic() string {
	topic := c.T
	return topic
}

// Group returns the consumer's group.
func (c *MockedConsumer) Group() string {
	group := c.G
	return group
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafkatest
import (
"context"
"github.com/cisco-open/go-lanai/pkg/kafka"
)
// MockedProducer is a kafka.Producer stand-in that records sent messages instead of publishing them.
type MockedProducer struct {
// topic
T string
// receives a MessageRecord for every Send
Recorder messageRecorder
}
// NewMockedProducer creates a producer for topic that hands every sent message to recorder.
func NewMockedProducer(topic string, recorder messageRecorder) *MockedProducer {
	p := &MockedProducer{T: topic}
	p.Recorder = recorder
	return p
}

// Topic returns the producer's topic.
func (p *MockedProducer) Topic() string {
	topic := p.T
	return topic
}

// Send records message (as a MessageRecord tagged with the producer's topic) instead of publishing
// it. The context and message options are ignored, and the error is always nil.
func (p *MockedProducer) Send(_ context.Context, message interface{}, _ ...kafka.MessageOptions) error {
	record := &MessageRecord{Topic: p.T, Payload: message}
	p.Recorder.Record(record)
	return nil
}

// ReadyCh reports readiness. The mock is always ready, so every call returns a new,
// already-closed channel that never blocks receivers.
func (p *MockedProducer) ReadyCh() <-chan struct{} {
	ready := make(chan struct{}, 1)
	close(ready)
	return ready
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafkatest
import (
"github.com/cisco-open/go-lanai/pkg/kafka"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// MockedSubscriber is a kafka.Subscriber stand-in whose embedded kafka.Dispatcher receives the
// messages injected via MockedBinder.Mock.
type MockedSubscriber struct {
kafka.Dispatcher
// topic
T string
// partitions reported by Partitions()
Parts []int32
}
// NewMockedSubscriber creates a subscriber for topic that logs through the package-level
// messageLogger and reports a single random partition in [0, 0xffff).
// NOTE(review): the upper bound assumes utils.RandomIntN(n) is exclusive of n — confirm.
func NewMockedSubscriber(topic string) *MockedSubscriber {
	partition := int32(utils.RandomIntN(0xffff))
	return &MockedSubscriber{
		T:          topic,
		Parts:      []int32{partition},
		Dispatcher: kafka.Dispatcher{Logger: messageLogger},
	}
}

// Topic returns the subscriber's topic.
func (s *MockedSubscriber) Topic() string {
	topic := s.T
	return topic
}

// Partitions returns the partitions assigned at construction.
func (s *MockedSubscriber) Partitions() []int32 {
	parts := s.Parts
	return parts
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package kafkatest
import (
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"go.uber.org/fx"
)
// logger is the package logger, named "KafkaTest".
var logger = log.New("KafkaTest")
// messageLogger is shared by every mocked subscriber/consumer dispatcher; it logs sent/received messages at info level.
var messageLogger = msgLogger{
logger: logger,
level: log.LevelInfo,
}
// WithMockedBinder returns a test.Options that provides a mocked kafka.Binder together with the
// same instance as *MockedBinder, MessageRecorder and MessageMocker.
// Tests can inject the MessageRecorder to verify calls to kafka.Producer, and the MessageMocker
// to simulate incoming messages.
//
// Note: this configuration exists to satisfy dependency injection and to check that
// kafka.Producer is invoked as expected. It does not apply or validate any message options
// (such as ValueEncoder or Key), and it ignores binding configuration.
func WithMockedBinder() test.Options {
	return test.WithOptions(
		apptest.WithFxOptions(fx.Provide(provideMockedBinder)),
	)
}
// Code generated by MockGen. DO NOT EDIT.
// Source: ../pkg/security/ctx.go
// Package authmock is a generated GoMock package.
package authmock
import (
security "github.com/cisco-open/go-lanai/pkg/security"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockAuthentication is a mock of Authentication interface
// (generated by mockgen from pkg/security/ctx.go; regenerate rather than hand-edit).
type MockAuthentication struct {
// controller that records expectations and dispatches calls
ctrl *gomock.Controller
// recorder returned by EXPECT()
recorder *MockAuthenticationMockRecorder
}
// MockAuthenticationMockRecorder is the mock recorder for MockAuthentication
type MockAuthenticationMockRecorder struct {
mock *MockAuthentication
}
// NewMockAuthentication creates a new mock instance bound to ctrl, with its recorder wired back to it
func NewMockAuthentication(ctrl *gomock.Controller) *MockAuthentication {
mock := &MockAuthentication{ctrl: ctrl}
mock.recorder = &MockAuthenticationMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockAuthentication) EXPECT() *MockAuthenticationMockRecorder {
return m.recorder
}
// Principal mocks base method: it forwards the call to the controller and returns the first
// configured return value (nil when none fits interface{}).
func (m *MockAuthentication) Principal() interface{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Principal")
ret0, _ := ret[0].(interface{})
return ret0
}
// Principal indicates an expected call of Principal
func (mr *MockAuthenticationMockRecorder) Principal() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Principal", reflect.TypeOf((*MockAuthentication)(nil).Principal))
}
// Permissions mocks base method: it forwards the call to the controller and returns the first
// configured return value.
//
// The value is asserted to the declared return type, security.Permissions. gomock's
// Call.Return converts assignable values, such as a plain map[string]interface{} literal, to
// the method's declared type. Asserting to the unnamed map type (as before) therefore never
// matched a security.Permissions, and the mock always returned nil.
func (m *MockAuthentication) Permissions() security.Permissions {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Permissions")
	ret0, _ := ret[0].(security.Permissions)
	return ret0
}
// Permissions indicates an expected call of Permissions.
// The method type recorded here (returning security.Permissions) is what gomock uses to check
// and convert the values passed to Return.
func (mr *MockAuthenticationMockRecorder) Permissions() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Permissions", reflect.TypeOf((*MockAuthentication)(nil).Permissions))
}
// State mocks base method: it returns the configured security.AuthenticationState
// (zero value when the configured value is not of that type).
func (m *MockAuthentication) State() security.AuthenticationState {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "State")
ret0, _ := ret[0].(security.AuthenticationState)
return ret0
}
// State indicates an expected call of State
func (mr *MockAuthenticationMockRecorder) State() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockAuthentication)(nil).State))
}
// Details mocks base method: it returns the first configured return value as-is.
func (m *MockAuthentication) Details() interface{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Details")
ret0, _ := ret[0].(interface{})
return ret0
}
// Details indicates an expected call of Details
func (mr *MockAuthenticationMockRecorder) Details() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Details", reflect.TypeOf((*MockAuthentication)(nil).Details))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package dsyncmock
import (
"context"
"github.com/cisco-open/go-lanai/pkg/dsync"
"go.uber.org/fx"
"sync"
)
// SimpleSyncManagerMock is a dsync.SyncManager whose locks (AlwaysLockMock) are always granted immediately.
type SimpleSyncManagerMock struct {}
// NoopOut is the fx result of ProvideNoopSyncManager; the manager joins the "dsync" value group.
type NoopOut struct {
fx.Out
TestSyncManager dsync.SyncManager `group:"dsync"`
}
// ProvideNoopSyncManager provides a SimpleSyncManagerMock into the "dsync" fx group.
func ProvideNoopSyncManager() NoopOut {
	var out NoopOut
	out.TestSyncManager = SimpleSyncManagerMock{}
	return out
}
// Lock returns a new AlwaysLockMock for key; lock options are ignored and the error is always nil.
func (m SimpleSyncManagerMock) Lock(key string, _ ...dsync.LockOptions) (dsync.Lock, error) {
	l := &AlwaysLockMock{}
	l.key = key
	return l, nil
}
// AlwaysLockMock is a lock whose Lock and TryLock always succeed immediately.
// Its "lost" channel is created by the first successful Lock/TryLock and closed by Release.
type AlwaysLockMock struct {
	mtx sync.Mutex
	key string
	ch  chan struct{}
}

// Key returns the key this lock was created for.
func (l *AlwaysLockMock) Key() string {
	return l.key
}

// Lock always succeeds without blocking; the context is ignored.
// The first call after construction or Release allocates the channel returned by Lost.
func (l *AlwaysLockMock) Lock(_ context.Context) error {
	l.mtx.Lock()
	defer l.mtx.Unlock()
	if l.ch != nil {
		// already held: keep handing out the same lost channel
		return nil
	}
	l.ch = make(chan struct{}, 1)
	return nil
}

// TryLock behaves exactly like Lock.
func (l *AlwaysLockMock) TryLock(ctx context.Context) error {
	return l.Lock(ctx)
}

// Release closes the current lost channel, which signals anyone holding it, and forgets it.
// Releasing a lock that is not held is a no-op. The error is always nil.
func (l *AlwaysLockMock) Release() error {
	l.mtx.Lock()
	defer l.mtx.Unlock()
	if l.ch == nil {
		return nil
	}
	close(l.ch)
	l.ch = nil
	return nil
}

// Lost returns the channel that Release closes. Before the first Lock and after Release it is
// nil, and receiving from a nil channel blocks forever.
func (l *AlwaysLockMock) Lost() <-chan struct{} {
	l.mtx.Lock()
	defer l.mtx.Unlock()
	return l.ch
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
// Package internal is an internal package that helps test generated mocks.
package internal
import (
"github.com/golang/mock/gomock"
"github.com/onsi/gomega"
. "github.com/onsi/gomega"
"reflect"
)
// AssertGoMockGenerated checks that the given "mock" implements every method of T (expected to be
// an interface type) and that each method's invocation is recorded by ctrl.
// It obtains the recorder by calling the mock's EXPECT() method, then runs
// AssertGoMockGeneratedMethod for every method of T.
//
// reflect.Value.MethodByName returns an invalid (zero) Value when the method does not exist, and
// calling IsZero on such a Value panics. Existence is therefore checked with IsValid, and the
// function stops as soon as a precondition assertion fails instead of panicking on the next line.
func AssertGoMockGenerated[T any](g *gomega.WithT, mock interface{}, ctrl *gomock.Controller) {
	var targetInterface T
	rtI := reflect.TypeOf(&targetInterface).Elem()
	rv := reflect.ValueOf(mock)
	// get recorder by invoking EXPECT()
	expectFn := rv.MethodByName("EXPECT")
	if !g.Expect(expectFn.IsValid()).To(BeTrue(), "mock %T should have EXPECT()", mock) {
		return
	}
	if !g.Expect(expectFn.Type().NumIn()).To(BeZero(), "mock %T should have EXPECT() with 0 input", mock) {
		return
	}
	out := InvokeFunc(expectFn, []reflect.Value{})
	if !g.Expect(out).To(HaveLen(1), "mock %T should have EXPECT() with 1 output", mock) {
		return
	}
	rExpect := out[0]
	// go through interface methods
	for i := 0; i < rtI.NumMethod(); i++ {
		name := rtI.Method(i).Name
		rm, ok := rv.Type().MethodByName(name)
		if !g.Expect(ok).To(BeTrue(), "actual mock should implement method [%s]", name) {
			continue
		}
		AssertGoMockGeneratedMethod(g, rm, rv, rExpect)
	}
}
// AssertGoMockGeneratedMethod checks one mocked method end to end:
//  1. calls the same-named method on expect (the EXPECT() recorder) with matchers for mocked
//     arguments, and requires it to return a *gomock.Call;
//  2. configures that call to return mocked values of the method's result types;
//  3. calls the method on receiver with the same mocked arguments and checks the result count.
//
// method comes from the mock's type, so the first parameter of method.Func is the receiver.
// Missing methods are detected with IsValid, because IsZero panics on an invalid reflect.Value.
// The function returns early when a precondition assertion fails.
func AssertGoMockGeneratedMethod(g *gomega.WithT, method reflect.Method, receiver reflect.Value, expect reflect.Value) {
	var out []reflect.Value
	g.Expect(method.IsExported()).To(BeTrue(), "method [%s] should be exported", method.Name)
	actualFn := receiver.MethodByName(method.Name)
	if !g.Expect(actualFn.IsValid()).To(BeTrue(), "mock should have matching method [%s]", method.Name) {
		return
	}
	expectFn := expect.MethodByName(method.Name)
	if !g.Expect(expectFn.IsValid()).To(BeTrue(), "EXPECT() should have matching method [%s]", method.Name) {
		return
	}
	// prepare input params
	ft := method.Func.Type()
	actualIn := make([]reflect.Value, 0, ft.NumIn())
	expectIn := make([]reflect.Value, 0, ft.NumIn())
	// Note: the first param in method is receiver
	for i, isVariadic, lastIdx := 1, ft.IsVariadic(), ft.NumIn()-1; i <= lastIdx; i++ {
		v := MockValue(ft.In(i), isVariadic && i == lastIdx)
		actualIn = append(actualIn, v)
		if isVariadic && i == lastIdx {
			// this is a varargs, the expectIn should be []interface{}{gomock.Any()}
			expectIn = append(expectIn, reflect.ValueOf([]interface{}{gomock.Any()}))
		} else {
			expectIn = append(expectIn, reflect.ValueOf(gomock.Eq(v.Interface())))
		}
	}
	// mock behavior
	out = InvokeFunc(expectFn, expectIn)
	if !g.Expect(out).To(HaveLen(1), "EXPECT().%s() should return 1 item", method.Name) {
		return
	}
	mockCall, isCall := out[0].Interface().(*gomock.Call)
	if !g.Expect(isCall).To(BeTrue(), "EXPECT().%s() should return %T", method.Name, &gomock.Call{}) {
		return
	}
	mockedRet := make([]interface{}, ft.NumOut())
	for i := 0; i < ft.NumOut(); i++ {
		mockedRet[i] = MockValue(ft.Out(i), false).Interface()
	}
	mockCall.Return(mockedRet...)
	// call actual method
	out = InvokeFunc(actualFn, actualIn)
	g.Expect(out).To(HaveLen(ft.NumOut()), "method [%s] should return correct number of parameters", method.Name)
}
// InvokeFunc calls fn with in and returns its results.
// For a variadic fn the last element of in must already be the slice holding the variadic
// arguments (reflect.Value.CallSlice semantics). Otherwise the arguments are passed one-to-one
// via reflect.Value.Call.
func InvokeFunc(fn reflect.Value, in []reflect.Value) []reflect.Value {
	call := fn.Call
	if fn.Type().IsVariadic() {
		call = fn.CallSlice
	}
	return call(in)
}
// MockValue returns a placeholder reflect.Value of type typ:
//   - pointer types: a pointer to a freshly allocated zero value (never nil);
//   - slice types: an empty slice, or a one-element slice when isVarargs is set;
//   - anything else: an addressable zero value.
func MockValue(typ reflect.Type, isVarargs bool) reflect.Value {
	kind := typ.Kind()
	if kind == reflect.Pointer {
		return reflect.New(typ.Elem())
	}
	if kind != reflect.Slice {
		return reflect.New(typ).Elem()
	}
	size := 0
	if isVarargs {
		// varargs carry exactly one element
		size = 1
	}
	return reflect.MakeSlice(typ, size, size)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package mocks
import (
"container/list"
"context"
"errors"
"github.com/google/uuid"
)
// TenancyRelation records that Child is a direct child tenant of Parent (UUID form).
//
// Deprecated: use the string version instead
type TenancyRelation struct {
Child uuid.UUID
Parent uuid.UUID
}
// TenancyRelationWithStrId records that ChildId is a direct child tenant of ParentId.
type TenancyRelationWithStrId struct {
ChildId string
ParentId string
}
// MockTenancyAccessor is an in-memory tenancy accessor whose lookup tables are built by ResetWithStrIds.
type MockTenancyAccessor struct {
// child -> parent
ParentLookup map[string]string
// parent -> direct children
ChildrenLookup map[string][]string
// tenant -> all descendants
DescendantsLookup map[string][]string
// tenant -> ancestors, nearest parent first
AncestorsLookup map[string][]string
// root tenant ID
Root string
// value reported by IsLoaded
Isloaded bool
}
// NewMockTenancyAccessor creates a loaded accessor from UUID relations and root.
//
// Deprecated: Use string version instead
func NewMockTenancyAccessor(tenantRelations []TenancyRelation, root uuid.UUID) *MockTenancyAccessor {
	// accessors report themselves as loaded by default
	accessor := &MockTenancyAccessor{Isloaded: true}
	accessor.Reset(tenantRelations, root)
	return accessor
}
// NewMockTenancyAccessorUsingStrIds creates a loaded accessor from string-ID relations and root.
func NewMockTenancyAccessorUsingStrIds(tenantRelations []TenancyRelationWithStrId, root string) *MockTenancyAccessor {
	accessor := &MockTenancyAccessor{Isloaded: true}
	accessor.ResetWithStrIds(tenantRelations, root)
	return accessor
}
// Reset converts the UUID relations and root to their string forms and delegates to ResetWithStrIds.
//
// Deprecated: Use the str version instead
func (m *MockTenancyAccessor) Reset(tenantRelations []TenancyRelation, root uuid.UUID) {
	converted := make([]TenancyRelationWithStrId, 0, len(tenantRelations))
	for i := range tenantRelations {
		converted = append(converted, TenancyRelationWithStrId{
			ChildId:  tenantRelations[i].Child.String(),
			ParentId: tenantRelations[i].Parent.String(),
		})
	}
	m.ResetWithStrIds(converted, root.String())
}
// ResetWithStrIds rebuilds every lookup table from tenantRelations and sets the root tenant ID.
//
// Each relation maps ChildId -> ParentId. From those it derives:
//   - ParentLookup: child -> parent (a later relation for the same child overwrites an earlier one);
//   - ChildrenLookup: parent -> direct children, in relation order;
//   - AncestorsLookup: child -> ancestors, nearest parent first;
//   - DescendantsLookup: parent -> all descendants, in breadth-first order.
//
// Both derived walks remember the tenants they have visited, so cyclic relation data no longer
// makes this method loop forever. For a tree (every tenant listed once) the results are
// unchanged. If a tenant is reachable more than once, its children are expanded only the first time.
func (m *MockTenancyAccessor) ResetWithStrIds(tenantRelations []TenancyRelationWithStrId, root string) {
	m.ParentLookup = make(map[string]string)
	m.ChildrenLookup = make(map[string][]string)
	m.DescendantsLookup = make(map[string][]string)
	m.AncestorsLookup = make(map[string][]string)
	m.Root = root

	// build the parent and children lookup
	for _, r := range tenantRelations {
		m.ParentLookup[r.ChildId] = r.ParentId
		m.ChildrenLookup[r.ParentId] = append(m.ChildrenLookup[r.ParentId], r.ChildId)
	}

	// build the ancestor lookup by following parent links upward
	for child := range m.ParentLookup {
		var ancestors []string
		visited := map[string]bool{child: true}
		current := child
		for {
			parent, ok := m.ParentLookup[current]
			if !ok || visited[parent] {
				break
			}
			visited[parent] = true
			ancestors = append(ancestors, parent)
			current = parent
		}
		m.AncestorsLookup[child] = ancestors
	}

	// build the descendant lookup with a breadth-first walk
	for parent := range m.ChildrenLookup {
		var descendants []string
		expanded := make(map[string]bool)
		idsToVisit := list.New()
		idsToVisit.PushBack(parent)
		for idsToVisit.Len() != 0 {
			id := idsToVisit.Remove(idsToVisit.Front()).(string)
			if expanded[id] {
				continue
			}
			expanded[id] = true
			children, ok := m.ChildrenLookup[id]
			if !ok {
				continue
			}
			for _, c := range children {
				idsToVisit.PushBack(c)
			}
			descendants = append(descendants, children...)
		}
		m.DescendantsLookup[parent] = descendants
	}
}
// GetParent returns the recorded parent of tenantId, or an error if none was recorded.
func (m *MockTenancyAccessor) GetParent(ctx context.Context, tenantId string) (string, error) {
	parent, found := m.ParentLookup[tenantId]
	if !found {
		return "", errors.New("parent not found")
	}
	return parent, nil
}

// GetChildren returns the direct children of tenantId, or an error if it has none recorded.
func (m *MockTenancyAccessor) GetChildren(ctx context.Context, tenantId string) ([]string, error) {
	children, found := m.ChildrenLookup[tenantId]
	if !found {
		return nil, errors.New("children not found")
	}
	return children, nil
}

// GetAncestors returns the ancestors of tenantId (nearest parent first). The root yields an
// empty, non-nil slice. An unknown tenant yields an error.
func (m *MockTenancyAccessor) GetAncestors(ctx context.Context, tenantId string) ([]string, error) {
	if tenantId == m.Root {
		return make([]string, 0), nil
	}
	ancestors, found := m.AncestorsLookup[tenantId]
	if !found {
		return nil, errors.New("ancestors not found")
	}
	return ancestors, nil
}

// GetDescendants returns every descendant of tenantId, or an error if it has none recorded.
func (m *MockTenancyAccessor) GetDescendants(ctx context.Context, tenantId string) ([]string, error) {
	descendants, found := m.DescendantsLookup[tenantId]
	if !found {
		return nil, errors.New("descendants not found")
	}
	return descendants, nil
}

// GetRoot returns the root tenant ID, or an error when it is empty.
func (m *MockTenancyAccessor) GetRoot(ctx context.Context) (string, error) {
	if m.Root == "" {
		return "", errors.New("root not set")
	}
	return m.Root, nil
}

// IsLoaded reports the Isloaded flag.
func (m *MockTenancyAccessor) IsLoaded(ctx context.Context) bool {
	loaded := m.Isloaded
	return loaded
}
// GetTenancyPath returns the tenant IDs from the root down to tenantId (inclusive) as UUIDs.
// It fails when tenantId or any ancestor is not a valid UUID, or when GetAncestors fails.
func (m *MockTenancyAccessor) GetTenancyPath(ctx context.Context, tenantId string) ([]uuid.UUID, error) {
	self, err := uuid.Parse(tenantId)
	if err != nil {
		return nil, err
	}
	ancestors, err := m.GetAncestors(ctx, tenantId)
	if err != nil {
		return nil, err
	}
	// ancestors are ordered nearest-first (parent, grandparent, ...), so write them back-to-front:
	// index 0 receives the farthest ancestor and the last slot holds tenantId itself.
	n := len(ancestors)
	path := make([]uuid.UUID, n+1)
	path[n] = self
	for i, s := range ancestors {
		id, parseErr := uuid.Parse(s)
		if parseErr != nil {
			return nil, parseErr
		}
		path[n-1-i] = id
	}
	return path, nil
}
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/go-redis/redis/v8 (interfaces: UniversalClient)
// Package redismock is a generated GoMock package.
package redismock
import (
context "context"
reflect "reflect"
time "time"
redis "github.com/go-redis/redis/v8"
gomock "github.com/golang/mock/gomock"
)
// MockUniversalClient is a mock of UniversalClient interface.
// Generated by mockgen from go-redis/v8; regenerate rather than hand-edit.
type MockUniversalClient struct {
// controller that records expectations and dispatches calls
ctrl *gomock.Controller
// recorder returned by EXPECT()
recorder *MockUniversalClientMockRecorder
}
// MockUniversalClientMockRecorder is the mock recorder for MockUniversalClient.
type MockUniversalClientMockRecorder struct {
mock *MockUniversalClient
}
// NewMockUniversalClient creates a new mock instance bound to ctrl, with its recorder wired back to it.
func NewMockUniversalClient(ctrl *gomock.Controller) *MockUniversalClient {
mock := &MockUniversalClient{ctrl: ctrl}
mock.recorder = &MockUniversalClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockUniversalClient) EXPECT() *MockUniversalClientMockRecorder {
return m.recorder
}
// AddHook mocks base method. It has no return value; the call is only forwarded to the controller.
func (m *MockUniversalClient) AddHook(arg0 redis.Hook) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddHook", arg0)
}
// AddHook indicates an expected call of AddHook.
func (mr *MockUniversalClientMockRecorder) AddHook(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddHook", reflect.TypeOf((*MockUniversalClient)(nil).AddHook), arg0)
}
// Append mocks base method. It returns the configured *redis.IntCmd (nil if none or of another type).
func (m *MockUniversalClient) Append(arg0 context.Context, arg1, arg2 string) *redis.IntCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Append", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.IntCmd)
return ret0
}
// Append indicates an expected call of Append.
func (mr *MockUniversalClientMockRecorder) Append(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Append", reflect.TypeOf((*MockUniversalClient)(nil).Append), arg0, arg1, arg2)
}
// BLMove mocks base method. It returns the configured *redis.StringCmd (nil if none or of another type).
func (m *MockUniversalClient) BLMove(arg0 context.Context, arg1, arg2, arg3, arg4 string, arg5 time.Duration) *redis.StringCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BLMove", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].(*redis.StringCmd)
return ret0
}
// BLMove indicates an expected call of BLMove.
func (mr *MockUniversalClientMockRecorder) BLMove(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BLMove", reflect.TypeOf((*MockUniversalClient)(nil).BLMove), arg0, arg1, arg2, arg3, arg4, arg5)
}
// BLPop mocks base method.
func (m *MockUniversalClient) BLPop(arg0 context.Context, arg1 time.Duration, arg2 ...string) *redis.StringSliceCmd {
m.ctrl.T.Helper()
// flatten the variadic keys into the argument list so each key is matched individually
varargs := []interface{}{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "BLPop", varargs...)
ret0, _ := ret[0].(*redis.StringSliceCmd)
return ret0
}
// BLPop indicates an expected call of BLPop.
func (mr *MockUniversalClientMockRecorder) BLPop(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BLPop", reflect.TypeOf((*MockUniversalClient)(nil).BLPop), varargs...)
}
// BRPop mocks base method.
func (m *MockUniversalClient) BRPop(arg0 context.Context, arg1 time.Duration, arg2 ...string) *redis.StringSliceCmd {
m.ctrl.T.Helper()
// flatten the variadic keys into the argument list so each key is matched individually
varargs := []interface{}{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "BRPop", varargs...)
ret0, _ := ret[0].(*redis.StringSliceCmd)
return ret0
}
// BRPop indicates an expected call of BRPop.
func (mr *MockUniversalClientMockRecorder) BRPop(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BRPop", reflect.TypeOf((*MockUniversalClient)(nil).BRPop), varargs...)
}
// BRPopLPush mocks base method. It returns the configured *redis.StringCmd (nil if none or of another type).
func (m *MockUniversalClient) BRPopLPush(arg0 context.Context, arg1, arg2 string, arg3 time.Duration) *redis.StringCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BRPopLPush", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*redis.StringCmd)
return ret0
}
// BRPopLPush indicates an expected call of BRPopLPush.
func (mr *MockUniversalClientMockRecorder) BRPopLPush(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BRPopLPush", reflect.TypeOf((*MockUniversalClient)(nil).BRPopLPush), arg0, arg1, arg2, arg3)
}
// BZPopMax mocks base method. Each element of arg2 is reported to the gomock
// controller as its own argument after arg0 and arg1; the first result is
// returned as a *redis.ZWithKeyCmd, or nil if it is not one.
func (m *MockUniversalClient) BZPopMax(arg0 context.Context, arg1 time.Duration, arg2 ...string) *redis.ZWithKeyCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	for i := range arg2 {
		callArgs = append(callArgs, arg2[i])
	}
	results := m.ctrl.Call(m, "BZPopMax", callArgs...)
	if cmd, ok := results[0].(*redis.ZWithKeyCmd); ok {
		return cmd
	}
	return nil
}
// BZPopMax indicates an expected call of BZPopMax; the variadic arg2 values
// are appended after the fixed arguments.
func (mr *MockUniversalClientMockRecorder) BZPopMax(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).BZPopMax)
	expected := append([]interface{}{arg0, arg1}, arg2...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BZPopMax", methodType, expected...)
}
// BZPopMin mocks base method. Each element of arg2 is reported to the gomock
// controller as its own argument after arg0 and arg1; the first result is
// returned as a *redis.ZWithKeyCmd, or nil if it is not one.
func (m *MockUniversalClient) BZPopMin(arg0 context.Context, arg1 time.Duration, arg2 ...string) *redis.ZWithKeyCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	for i := range arg2 {
		callArgs = append(callArgs, arg2[i])
	}
	results := m.ctrl.Call(m, "BZPopMin", callArgs...)
	if cmd, ok := results[0].(*redis.ZWithKeyCmd); ok {
		return cmd
	}
	return nil
}
// BZPopMin indicates an expected call of BZPopMin; the variadic arg2 values
// are appended after the fixed arguments.
func (mr *MockUniversalClientMockRecorder) BZPopMin(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).BZPopMin)
	expected := append([]interface{}{arg0, arg1}, arg2...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BZPopMin", methodType, expected...)
}
// BgRewriteAOF mocks base method. It reports the call to the gomock controller
// and returns the first result as a *redis.StatusCmd, or nil if it is not one.
func (m *MockUniversalClient) BgRewriteAOF(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "BgRewriteAOF", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// BgRewriteAOF indicates an expected call of BgRewriteAOF, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) BgRewriteAOF(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).BgRewriteAOF)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BgRewriteAOF", methodType, arg0)
}
// BgSave mocks base method. It reports the call to the gomock controller and
// returns the first result as a *redis.StatusCmd, or nil if it is not one.
func (m *MockUniversalClient) BgSave(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "BgSave", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// BgSave indicates an expected call of BgSave, keyed by the method's reflected
// type.
func (mr *MockUniversalClientMockRecorder) BgSave(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).BgSave)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BgSave", methodType, arg0)
}
// BitCount mocks base method. It reports the call and its arguments to the
// gomock controller and returns the first result as a *redis.IntCmd, or nil if
// it is not one.
func (m *MockUniversalClient) BitCount(arg0 context.Context, arg1 string, arg2 *redis.BitCount) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "BitCount", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// BitCount indicates an expected call of BitCount, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) BitCount(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).BitCount)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BitCount", methodType, arg0, arg1, arg2)
}
// BitField mocks base method. Each element of arg2 is reported to the gomock
// controller as its own argument after arg0 and arg1; the first result is
// returned as a *redis.IntSliceCmd, or nil if it is not one.
func (m *MockUniversalClient) BitField(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.IntSliceCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	results := m.ctrl.Call(m, "BitField", callArgs...)
	if cmd, ok := results[0].(*redis.IntSliceCmd); ok {
		return cmd
	}
	return nil
}
// BitField indicates an expected call of BitField; the variadic arg2 values
// are appended after the fixed arguments.
func (mr *MockUniversalClientMockRecorder) BitField(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).BitField)
	expected := append([]interface{}{arg0, arg1}, arg2...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BitField", methodType, expected...)
}
// BitOpAnd mocks base method. Each element of arg2 is reported to the gomock
// controller as its own argument after arg0 and arg1; the first result is
// returned as a *redis.IntCmd, or nil if it is not one.
func (m *MockUniversalClient) BitOpAnd(arg0 context.Context, arg1 string, arg2 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	for i := range arg2 {
		callArgs = append(callArgs, arg2[i])
	}
	results := m.ctrl.Call(m, "BitOpAnd", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// BitOpAnd indicates an expected call of BitOpAnd; the variadic arg2 values
// are appended after the fixed arguments.
func (mr *MockUniversalClientMockRecorder) BitOpAnd(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).BitOpAnd)
	expected := append([]interface{}{arg0, arg1}, arg2...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BitOpAnd", methodType, expected...)
}
// BitOpNot mocks base method. It reports the call and its arguments to the
// gomock controller and returns the first result as a *redis.IntCmd, or nil if
// it is not one.
func (m *MockUniversalClient) BitOpNot(arg0 context.Context, arg1, arg2 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "BitOpNot", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// BitOpNot indicates an expected call of BitOpNot, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) BitOpNot(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).BitOpNot)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BitOpNot", methodType, arg0, arg1, arg2)
}
// BitOpOr mocks base method. Each element of arg2 is reported to the gomock
// controller as its own argument after arg0 and arg1; the first result is
// returned as a *redis.IntCmd, or nil if it is not one.
func (m *MockUniversalClient) BitOpOr(arg0 context.Context, arg1 string, arg2 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	for i := range arg2 {
		callArgs = append(callArgs, arg2[i])
	}
	results := m.ctrl.Call(m, "BitOpOr", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// BitOpOr indicates an expected call of BitOpOr; the variadic arg2 values are
// appended after the fixed arguments.
func (mr *MockUniversalClientMockRecorder) BitOpOr(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).BitOpOr)
	expected := append([]interface{}{arg0, arg1}, arg2...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BitOpOr", methodType, expected...)
}
// BitOpXor mocks base method. Each element of arg2 is reported to the gomock
// controller as its own argument after arg0 and arg1; the first result is
// returned as a *redis.IntCmd, or nil if it is not one.
func (m *MockUniversalClient) BitOpXor(arg0 context.Context, arg1 string, arg2 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	for i := range arg2 {
		callArgs = append(callArgs, arg2[i])
	}
	results := m.ctrl.Call(m, "BitOpXor", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// BitOpXor indicates an expected call of BitOpXor; the variadic arg2 values
// are appended after the fixed arguments.
func (mr *MockUniversalClientMockRecorder) BitOpXor(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).BitOpXor)
	expected := append([]interface{}{arg0, arg1}, arg2...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BitOpXor", methodType, expected...)
}
// BitPos mocks base method. Each element of arg3 is reported to the gomock
// controller as its own argument after arg0, arg1 and arg2; the first result
// is returned as a *redis.IntCmd, or nil if it is not one.
func (m *MockUniversalClient) BitPos(arg0 context.Context, arg1 string, arg2 int64, arg3 ...int64) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 3+len(arg3))
	callArgs = append(callArgs, arg0, arg1, arg2)
	for i := range arg3 {
		callArgs = append(callArgs, arg3[i])
	}
	results := m.ctrl.Call(m, "BitPos", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// BitPos indicates an expected call of BitPos; the variadic arg3 values are
// appended after the fixed arguments.
func (mr *MockUniversalClientMockRecorder) BitPos(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).BitPos)
	expected := append([]interface{}{arg0, arg1, arg2}, arg3...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BitPos", methodType, expected...)
}
// ClientGetName mocks base method. It reports the call to the gomock
// controller and returns the first result as a *redis.StringCmd, or nil if it
// is not one.
func (m *MockUniversalClient) ClientGetName(arg0 context.Context) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClientGetName", arg0)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// ClientGetName indicates an expected call of ClientGetName, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClientGetName(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClientGetName)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClientGetName", methodType, arg0)
}
// ClientID mocks base method. It reports the call to the gomock controller and
// returns the first result as a *redis.IntCmd, or nil if it is not one.
func (m *MockUniversalClient) ClientID(arg0 context.Context) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClientID", arg0)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// ClientID indicates an expected call of ClientID, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) ClientID(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClientID)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClientID", methodType, arg0)
}
// ClientKill mocks base method. It reports the call and its arguments to the
// gomock controller and returns the first result as a *redis.StatusCmd, or nil
// if it is not one.
func (m *MockUniversalClient) ClientKill(arg0 context.Context, arg1 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClientKill", arg0, arg1)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClientKill indicates an expected call of ClientKill, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) ClientKill(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClientKill)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClientKill", methodType, arg0, arg1)
}
// ClientKillByFilter mocks base method. Each element of arg1 is reported to the
// gomock controller as its own argument after arg0; the first result is
// returned as a *redis.IntCmd, or nil if it is not one.
func (m *MockUniversalClient) ClientKillByFilter(arg0 context.Context, arg1 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	for i := range arg1 {
		callArgs = append(callArgs, arg1[i])
	}
	results := m.ctrl.Call(m, "ClientKillByFilter", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// ClientKillByFilter indicates an expected call of ClientKillByFilter; the
// variadic arg1 values are appended after arg0.
func (mr *MockUniversalClientMockRecorder) ClientKillByFilter(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClientKillByFilter)
	expected := append([]interface{}{arg0}, arg1...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClientKillByFilter", methodType, expected...)
}
// ClientList mocks base method. It reports the call to the gomock controller
// and returns the first result as a *redis.StringCmd, or nil if it is not one.
func (m *MockUniversalClient) ClientList(arg0 context.Context) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClientList", arg0)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// ClientList indicates an expected call of ClientList, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) ClientList(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClientList)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClientList", methodType, arg0)
}
// ClientPause mocks base method. It reports the call and its arguments to the
// gomock controller and returns the first result as a *redis.BoolCmd, or nil
// if it is not one.
func (m *MockUniversalClient) ClientPause(arg0 context.Context, arg1 time.Duration) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClientPause", arg0, arg1)
	if cmd, ok := results[0].(*redis.BoolCmd); ok {
		return cmd
	}
	return nil
}
// ClientPause indicates an expected call of ClientPause, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) ClientPause(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClientPause)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClientPause", methodType, arg0, arg1)
}
// Close mocks base method. It reports the call to the gomock controller and
// returns the first result as an error, or nil if it is not one.
func (m *MockUniversalClient) Close() error {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Close")
	if err, ok := results[0].(error); ok {
		return err
	}
	return nil
}
// Close indicates an expected call of Close, keyed by the method's reflected
// type.
func (mr *MockUniversalClientMockRecorder) Close() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).Close)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", methodType)
}
// ClusterAddSlots mocks base method. Each element of arg1 is reported to the
// gomock controller as its own argument after arg0; the first result is
// returned as a *redis.StatusCmd, or nil if it is not one.
func (m *MockUniversalClient) ClusterAddSlots(arg0 context.Context, arg1 ...int) *redis.StatusCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	for i := range arg1 {
		callArgs = append(callArgs, arg1[i])
	}
	results := m.ctrl.Call(m, "ClusterAddSlots", callArgs...)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClusterAddSlots indicates an expected call of ClusterAddSlots; the variadic
// arg1 values are appended after arg0.
func (mr *MockUniversalClientMockRecorder) ClusterAddSlots(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterAddSlots)
	expected := append([]interface{}{arg0}, arg1...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterAddSlots", methodType, expected...)
}
// ClusterAddSlotsRange mocks base method. It reports the call and its
// arguments to the gomock controller and returns the first result as a
// *redis.StatusCmd, or nil if it is not one.
func (m *MockUniversalClient) ClusterAddSlotsRange(arg0 context.Context, arg1, arg2 int) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterAddSlotsRange", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClusterAddSlotsRange indicates an expected call of ClusterAddSlotsRange,
// keyed by the method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterAddSlotsRange(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterAddSlotsRange)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterAddSlotsRange", methodType, arg0, arg1, arg2)
}
// ClusterCountFailureReports mocks base method. It reports the call and its
// arguments to the gomock controller and returns the first result as a
// *redis.IntCmd, or nil if it is not one.
func (m *MockUniversalClient) ClusterCountFailureReports(arg0 context.Context, arg1 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterCountFailureReports", arg0, arg1)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// ClusterCountFailureReports indicates an expected call of
// ClusterCountFailureReports, keyed by the method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterCountFailureReports(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterCountFailureReports)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterCountFailureReports", methodType, arg0, arg1)
}
// ClusterCountKeysInSlot mocks base method. It reports the call and its
// arguments to the gomock controller and returns the first result as a
// *redis.IntCmd, or nil if it is not one.
func (m *MockUniversalClient) ClusterCountKeysInSlot(arg0 context.Context, arg1 int) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterCountKeysInSlot", arg0, arg1)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// ClusterCountKeysInSlot indicates an expected call of ClusterCountKeysInSlot,
// keyed by the method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterCountKeysInSlot(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterCountKeysInSlot)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterCountKeysInSlot", methodType, arg0, arg1)
}
// ClusterDelSlots mocks base method. Each element of arg1 is reported to the
// gomock controller as its own argument after arg0; the first result is
// returned as a *redis.StatusCmd, or nil if it is not one.
func (m *MockUniversalClient) ClusterDelSlots(arg0 context.Context, arg1 ...int) *redis.StatusCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	for i := range arg1 {
		callArgs = append(callArgs, arg1[i])
	}
	results := m.ctrl.Call(m, "ClusterDelSlots", callArgs...)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClusterDelSlots indicates an expected call of ClusterDelSlots; the variadic
// arg1 values are appended after arg0.
func (mr *MockUniversalClientMockRecorder) ClusterDelSlots(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterDelSlots)
	expected := append([]interface{}{arg0}, arg1...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterDelSlots", methodType, expected...)
}
// ClusterDelSlotsRange mocks base method. It reports the call and its
// arguments to the gomock controller and returns the first result as a
// *redis.StatusCmd, or nil if it is not one.
func (m *MockUniversalClient) ClusterDelSlotsRange(arg0 context.Context, arg1, arg2 int) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterDelSlotsRange", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClusterDelSlotsRange indicates an expected call of ClusterDelSlotsRange,
// keyed by the method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterDelSlotsRange(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterDelSlotsRange)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterDelSlotsRange", methodType, arg0, arg1, arg2)
}
// ClusterFailover mocks base method. It reports the call to the gomock
// controller and returns the first result as a *redis.StatusCmd, or nil if it
// is not one.
func (m *MockUniversalClient) ClusterFailover(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterFailover", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClusterFailover indicates an expected call of ClusterFailover, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterFailover(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterFailover)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterFailover", methodType, arg0)
}
// ClusterForget mocks base method. It reports the call and its arguments to
// the gomock controller and returns the first result as a *redis.StatusCmd, or
// nil if it is not one.
func (m *MockUniversalClient) ClusterForget(arg0 context.Context, arg1 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterForget", arg0, arg1)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClusterForget indicates an expected call of ClusterForget, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterForget(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterForget)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterForget", methodType, arg0, arg1)
}
// ClusterGetKeysInSlot mocks base method. It reports the call and its
// arguments to the gomock controller and returns the first result as a
// *redis.StringSliceCmd, or nil if it is not one.
func (m *MockUniversalClient) ClusterGetKeysInSlot(arg0 context.Context, arg1, arg2 int) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterGetKeysInSlot", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.StringSliceCmd); ok {
		return cmd
	}
	return nil
}
// ClusterGetKeysInSlot indicates an expected call of ClusterGetKeysInSlot,
// keyed by the method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterGetKeysInSlot(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterGetKeysInSlot)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterGetKeysInSlot", methodType, arg0, arg1, arg2)
}
// ClusterInfo mocks base method. It reports the call to the gomock controller
// and returns the first result as a *redis.StringCmd, or nil if it is not one.
func (m *MockUniversalClient) ClusterInfo(arg0 context.Context) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterInfo", arg0)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// ClusterInfo indicates an expected call of ClusterInfo, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterInfo(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterInfo)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterInfo", methodType, arg0)
}
// ClusterKeySlot mocks base method. It reports the call and its arguments to
// the gomock controller and returns the first result as a *redis.IntCmd, or
// nil if it is not one.
func (m *MockUniversalClient) ClusterKeySlot(arg0 context.Context, arg1 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterKeySlot", arg0, arg1)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// ClusterKeySlot indicates an expected call of ClusterKeySlot, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterKeySlot(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterKeySlot)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterKeySlot", methodType, arg0, arg1)
}
// ClusterMeet mocks base method. It reports the call and its arguments to the
// gomock controller and returns the first result as a *redis.StatusCmd, or nil
// if it is not one.
func (m *MockUniversalClient) ClusterMeet(arg0 context.Context, arg1, arg2 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterMeet", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClusterMeet indicates an expected call of ClusterMeet, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterMeet(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterMeet)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterMeet", methodType, arg0, arg1, arg2)
}
// ClusterNodes mocks base method. It reports the call to the gomock controller
// and returns the first result as a *redis.StringCmd, or nil if it is not one.
func (m *MockUniversalClient) ClusterNodes(arg0 context.Context) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterNodes", arg0)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// ClusterNodes indicates an expected call of ClusterNodes, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterNodes(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterNodes)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterNodes", methodType, arg0)
}
// ClusterReplicate mocks base method. It reports the call and its arguments to
// the gomock controller and returns the first result as a *redis.StatusCmd, or
// nil if it is not one.
func (m *MockUniversalClient) ClusterReplicate(arg0 context.Context, arg1 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterReplicate", arg0, arg1)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClusterReplicate indicates an expected call of ClusterReplicate, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterReplicate(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterReplicate)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterReplicate", methodType, arg0, arg1)
}
// ClusterResetHard mocks base method. It reports the call to the gomock
// controller and returns the first result as a *redis.StatusCmd, or nil if it
// is not one.
func (m *MockUniversalClient) ClusterResetHard(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterResetHard", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClusterResetHard indicates an expected call of ClusterResetHard, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterResetHard(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterResetHard)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterResetHard", methodType, arg0)
}
// ClusterResetSoft mocks base method. It reports the call to the gomock
// controller and returns the first result as a *redis.StatusCmd, or nil if it
// is not one.
func (m *MockUniversalClient) ClusterResetSoft(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterResetSoft", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClusterResetSoft indicates an expected call of ClusterResetSoft, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterResetSoft(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterResetSoft)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterResetSoft", methodType, arg0)
}
// ClusterSaveConfig mocks base method. It reports the call to the gomock
// controller and returns the first result as a *redis.StatusCmd, or nil if it
// is not one.
func (m *MockUniversalClient) ClusterSaveConfig(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterSaveConfig", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ClusterSaveConfig indicates an expected call of ClusterSaveConfig, keyed by
// the method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterSaveConfig(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterSaveConfig)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSaveConfig", methodType, arg0)
}
// ClusterSlaves mocks base method. It reports the call and its arguments to
// the gomock controller and returns the first result as a
// *redis.StringSliceCmd, or nil if it is not one.
func (m *MockUniversalClient) ClusterSlaves(arg0 context.Context, arg1 string) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterSlaves", arg0, arg1)
	if cmd, ok := results[0].(*redis.StringSliceCmd); ok {
		return cmd
	}
	return nil
}
// ClusterSlaves indicates an expected call of ClusterSlaves, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterSlaves(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterSlaves)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSlaves", methodType, arg0, arg1)
}
// ClusterSlots mocks base method. It reports the call to the gomock controller
// and returns the first result as a *redis.ClusterSlotsCmd, or nil if it is not
// one.
func (m *MockUniversalClient) ClusterSlots(arg0 context.Context) *redis.ClusterSlotsCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ClusterSlots", arg0)
	if cmd, ok := results[0].(*redis.ClusterSlotsCmd); ok {
		return cmd
	}
	return nil
}
// ClusterSlots indicates an expected call of ClusterSlots, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ClusterSlots(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ClusterSlots)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSlots", methodType, arg0)
}
// Command mocks base method. It reports the call to the gomock controller and
// returns the first result as a *redis.CommandsInfoCmd, or nil if it is not
// one.
func (m *MockUniversalClient) Command(arg0 context.Context) *redis.CommandsInfoCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Command", arg0)
	if cmd, ok := results[0].(*redis.CommandsInfoCmd); ok {
		return cmd
	}
	return nil
}
// Command indicates an expected call of Command, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) Command(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).Command)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Command", methodType, arg0)
}
// ConfigGet mocks base method. It reports the call and its arguments to the
// gomock controller and returns the first result as a *redis.SliceCmd, or nil
// if it is not one.
func (m *MockUniversalClient) ConfigGet(arg0 context.Context, arg1 string) *redis.SliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ConfigGet", arg0, arg1)
	if cmd, ok := results[0].(*redis.SliceCmd); ok {
		return cmd
	}
	return nil
}
// ConfigGet indicates an expected call of ConfigGet, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) ConfigGet(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ConfigGet)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConfigGet", methodType, arg0, arg1)
}
// ConfigResetStat mocks base method. It reports the call to the gomock
// controller and returns the first result as a *redis.StatusCmd, or nil if it
// is not one.
func (m *MockUniversalClient) ConfigResetStat(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ConfigResetStat", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ConfigResetStat indicates an expected call of ConfigResetStat, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ConfigResetStat(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ConfigResetStat)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConfigResetStat", methodType, arg0)
}
// ConfigRewrite mocks base method. It reports the call to the gomock
// controller and returns the first result as a *redis.StatusCmd, or nil if it
// is not one.
func (m *MockUniversalClient) ConfigRewrite(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ConfigRewrite", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ConfigRewrite indicates an expected call of ConfigRewrite, keyed by the
// method's reflected type.
func (mr *MockUniversalClientMockRecorder) ConfigRewrite(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ConfigRewrite)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConfigRewrite", methodType, arg0)
}
// ConfigSet mocks base method. It reports the call and its arguments to the
// gomock controller and returns the first result as a *redis.StatusCmd, or nil
// if it is not one.
func (m *MockUniversalClient) ConfigSet(arg0 context.Context, arg1, arg2 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ConfigSet", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ConfigSet indicates an expected call of ConfigSet, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) ConfigSet(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).ConfigSet)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConfigSet", methodType, arg0, arg1, arg2)
}
// Context mocks base method. It reports the call to the gomock controller and
// returns the first result as a context.Context, or nil if it is not one.
func (m *MockUniversalClient) Context() context.Context {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Context")
	if ctx, ok := results[0].(context.Context); ok {
		return ctx
	}
	return nil
}
// Context indicates an expected call of Context, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) Context() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).Context)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", methodType)
}
// Copy mocks base method. It reports the call and its arguments to the gomock
// controller and returns the first result as a *redis.IntCmd, or nil if it is
// not one.
func (m *MockUniversalClient) Copy(arg0 context.Context, arg1, arg2 string, arg3 int, arg4 bool) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Copy", arg0, arg1, arg2, arg3, arg4)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// Copy indicates an expected call of Copy, keyed by the method's reflected
// type.
func (mr *MockUniversalClientMockRecorder) Copy(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).Copy)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Copy", methodType, arg0, arg1, arg2, arg3, arg4)
}
// DBSize mocks base method. It reports the call to the gomock controller and
// returns the first result as a *redis.IntCmd, or nil if it is not one.
func (m *MockUniversalClient) DBSize(arg0 context.Context) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "DBSize", arg0)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// DBSize indicates an expected call of DBSize, keyed by the method's reflected
// type.
func (mr *MockUniversalClientMockRecorder) DBSize(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).DBSize)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DBSize", methodType, arg0)
}
// DebugObject mocks base method. It reports the call and its arguments to the
// gomock controller and returns the first result as a *redis.StringCmd, or nil
// if it is not one.
func (m *MockUniversalClient) DebugObject(arg0 context.Context, arg1 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "DebugObject", arg0, arg1)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// DebugObject indicates an expected call of DebugObject, keyed by the method's
// reflected type.
func (mr *MockUniversalClientMockRecorder) DebugObject(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).DebugObject)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugObject", methodType, arg0, arg1)
}
// Decr mocks base method. It reports the call and its arguments to the gomock
// controller and returns the first result as a *redis.IntCmd, or nil if it is
// not one.
func (m *MockUniversalClient) Decr(arg0 context.Context, arg1 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Decr", arg0, arg1)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// Decr indicates an expected call of Decr, keyed by the method's reflected
// type.
func (mr *MockUniversalClientMockRecorder) Decr(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).Decr)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decr", methodType, arg0, arg1)
}
// DecrBy mocks base method. It reports the call and its arguments to the
// gomock controller and returns the first result as a *redis.IntCmd, or nil if
// it is not one.
func (m *MockUniversalClient) DecrBy(arg0 context.Context, arg1 string, arg2 int64) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "DecrBy", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// DecrBy indicates an expected call of DecrBy, keyed by the method's reflected
// type.
func (mr *MockUniversalClientMockRecorder) DecrBy(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).DecrBy)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrBy", methodType, arg0, arg1, arg2)
}
// Del mocks base method. Each element of arg1 is reported to the gomock
// controller as its own argument after arg0; the first result is returned as
// a *redis.IntCmd, or nil if it is not one.
func (m *MockUniversalClient) Del(arg0 context.Context, arg1 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	for i := range arg1 {
		callArgs = append(callArgs, arg1[i])
	}
	results := m.ctrl.Call(m, "Del", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// Del indicates an expected call of Del; the variadic arg1 values are appended
// after arg0.
func (mr *MockUniversalClientMockRecorder) Del(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).Del)
	expected := append([]interface{}{arg0}, arg1...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Del", methodType, expected...)
}
// Do mocks base method. Each element of arg1 is reported to the gomock
// controller as its own argument after arg0; the first result is returned as
// a *redis.Cmd, or nil if it is not one.
func (m *MockUniversalClient) Do(arg0 context.Context, arg1 ...interface{}) *redis.Cmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	callArgs = append(callArgs, arg1...)
	results := m.ctrl.Call(m, "Do", callArgs...)
	if cmd, ok := results[0].(*redis.Cmd); ok {
		return cmd
	}
	return nil
}
// Do indicates an expected call of Do; the variadic arg1 values are appended
// after arg0.
func (mr *MockUniversalClientMockRecorder) Do(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).Do)
	expected := append([]interface{}{arg0}, arg1...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", methodType, expected...)
}
// Dump mocks base method. It reports the call and its arguments to the gomock
// controller and returns the first result as a *redis.StringCmd, or nil if it
// is not one.
func (m *MockUniversalClient) Dump(arg0 context.Context, arg1 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Dump", arg0, arg1)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// Dump indicates an expected call of Dump, keyed by the method's reflected
// type.
func (mr *MockUniversalClientMockRecorder) Dump(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).Dump)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dump", methodType, arg0, arg1)
}
// Echo mocks base method. It reports the call and its arguments to the gomock
// controller and returns the first result as a *redis.StringCmd, or nil if it
// is not one.
func (m *MockUniversalClient) Echo(arg0 context.Context, arg1 interface{}) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Echo", arg0, arg1)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// Echo indicates an expected call of Echo, keyed by the method's reflected
// type.
func (mr *MockUniversalClientMockRecorder) Echo(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).Echo)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Echo", methodType, arg0, arg1)
}
// Eval mocks base method. arg2 is reported as a single slice argument, and
// each element of arg3 is reported as its own argument after it; the first
// result is returned as a *redis.Cmd, or nil if it is not one.
func (m *MockUniversalClient) Eval(arg0 context.Context, arg1 string, arg2 []string, arg3 ...interface{}) *redis.Cmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 3+len(arg3))
	callArgs = append(callArgs, arg0, arg1, arg2)
	callArgs = append(callArgs, arg3...)
	results := m.ctrl.Call(m, "Eval", callArgs...)
	if cmd, ok := results[0].(*redis.Cmd); ok {
		return cmd
	}
	return nil
}
// Eval indicates an expected call of Eval; the variadic arg3 values are
// appended after the fixed arguments.
func (mr *MockUniversalClientMockRecorder) Eval(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	methodType := reflect.TypeOf((*MockUniversalClient)(nil).Eval)
	expected := append([]interface{}{arg0, arg1, arg2}, arg3...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Eval", methodType, expected...)
}
// EvalSha mocks base method.
func (m *MockUniversalClient) EvalSha(arg0 context.Context, arg1 string, arg2 []string, arg3 ...interface{}) *redis.Cmd {
m.ctrl.T.Helper()
varargs := []interface{}{arg0, arg1, arg2}
for _, a := range arg3 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "EvalSha", varargs...)
ret0, _ := ret[0].(*redis.Cmd)
return ret0
}
// EvalSha records an expected EvalSha call, flattening the fixed and variadic
// arguments into a single list, and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) EvalSha(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	expected := []interface{}{arg0, arg1, arg2}
	expected = append(expected, arg3...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).EvalSha)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EvalSha", method, expected...)
}
// Exists flattens the context and variadic strings into one argument list,
// forwards the invocation to the gomock controller, and returns the configured
// *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) Exists(arg0 context.Context, arg1 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	// Box each element individually: a []string cannot be spread into []interface{}.
	for _, elem := range arg1 {
		callArgs = append(callArgs, elem)
	}
	results := m.ctrl.Call(m, "Exists", callArgs...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// Exists records an expected Exists call, flattening the fixed and variadic
// arguments into a single list, and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Exists(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	expected := []interface{}{arg0}
	expected = append(expected, arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).Exists)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", method, expected...)
}
// Expire forwards the invocation to the gomock controller and returns the
// configured *redis.BoolCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) Expire(arg0 context.Context, arg1 string, arg2 time.Duration) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Expire", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// Expire records an expected Expire call with the given arguments and returns
// the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) Expire(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Expire)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Expire", method, arg0, arg1, arg2)
}
// ExpireAt forwards the invocation to the gomock controller and returns the
// configured *redis.BoolCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) ExpireAt(arg0 context.Context, arg1 string, arg2 time.Time) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ExpireAt", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// ExpireAt records an expected ExpireAt call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) ExpireAt(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ExpireAt)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpireAt", method, arg0, arg1, arg2)
}
// ExpireGT forwards the invocation to the gomock controller and returns the
// configured *redis.BoolCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) ExpireGT(arg0 context.Context, arg1 string, arg2 time.Duration) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ExpireGT", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// ExpireGT records an expected ExpireGT call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) ExpireGT(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ExpireGT)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpireGT", method, arg0, arg1, arg2)
}
// ExpireLT forwards the invocation to the gomock controller and returns the
// configured *redis.BoolCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) ExpireLT(arg0 context.Context, arg1 string, arg2 time.Duration) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ExpireLT", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// ExpireLT records an expected ExpireLT call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) ExpireLT(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ExpireLT)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpireLT", method, arg0, arg1, arg2)
}
// ExpireNX forwards the invocation to the gomock controller and returns the
// configured *redis.BoolCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) ExpireNX(arg0 context.Context, arg1 string, arg2 time.Duration) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ExpireNX", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// ExpireNX records an expected ExpireNX call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) ExpireNX(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ExpireNX)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpireNX", method, arg0, arg1, arg2)
}
// ExpireXX forwards the invocation to the gomock controller and returns the
// configured *redis.BoolCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) ExpireXX(arg0 context.Context, arg1 string, arg2 time.Duration) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ExpireXX", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// ExpireXX records an expected ExpireXX call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) ExpireXX(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ExpireXX)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpireXX", method, arg0, arg1, arg2)
}
// FlushAll forwards the invocation to the gomock controller and returns the
// configured *redis.StatusCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) FlushAll(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "FlushAll", arg0)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// FlushAll records an expected FlushAll call with the given argument and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) FlushAll(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).FlushAll)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FlushAll", method, arg0)
}
// FlushAllAsync forwards the invocation to the gomock controller and returns
// the configured *redis.StatusCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) FlushAllAsync(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "FlushAllAsync", arg0)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// FlushAllAsync records an expected FlushAllAsync call with the given argument
// and returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) FlushAllAsync(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).FlushAllAsync)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FlushAllAsync", method, arg0)
}
// FlushDB forwards the invocation to the gomock controller and returns the
// configured *redis.StatusCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) FlushDB(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "FlushDB", arg0)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// FlushDB records an expected FlushDB call with the given argument and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) FlushDB(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).FlushDB)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FlushDB", method, arg0)
}
// FlushDBAsync forwards the invocation to the gomock controller and returns
// the configured *redis.StatusCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) FlushDBAsync(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "FlushDBAsync", arg0)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// FlushDBAsync records an expected FlushDBAsync call with the given argument
// and returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) FlushDBAsync(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).FlushDBAsync)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FlushDBAsync", method, arg0)
}
// GeoAdd flattens the fixed and variadic arguments into one list, forwards
// the invocation to the gomock controller, and returns the configured
// *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GeoAdd(arg0 context.Context, arg1 string, arg2 ...*redis.GeoLocation) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	// Box each element individually: a []*redis.GeoLocation cannot be spread into []interface{}.
	for _, elem := range arg2 {
		callArgs = append(callArgs, elem)
	}
	results := m.ctrl.Call(m, "GeoAdd", callArgs...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// GeoAdd records an expected GeoAdd call, flattening the fixed and variadic
// arguments into a single list, and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) GeoAdd(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	expected := []interface{}{arg0, arg1}
	expected = append(expected, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).GeoAdd)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeoAdd", method, expected...)
}
// GeoDist forwards the invocation to the gomock controller and returns the
// configured *redis.FloatCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GeoDist(arg0 context.Context, arg1, arg2, arg3, arg4 string) *redis.FloatCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GeoDist", arg0, arg1, arg2, arg3, arg4)
	cmd, _ := results[0].(*redis.FloatCmd)
	return cmd
}
// GeoDist records an expected GeoDist call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GeoDist(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GeoDist)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeoDist", method, arg0, arg1, arg2, arg3, arg4)
}
// GeoHash flattens the fixed and variadic arguments into one list, forwards
// the invocation to the gomock controller, and returns the configured
// *redis.StringSliceCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GeoHash(arg0 context.Context, arg1 string, arg2 ...string) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	// Box each element individually: a []string cannot be spread into []interface{}.
	for _, elem := range arg2 {
		callArgs = append(callArgs, elem)
	}
	results := m.ctrl.Call(m, "GeoHash", callArgs...)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// GeoHash records an expected GeoHash call, flattening the fixed and variadic
// arguments into a single list, and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) GeoHash(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	expected := []interface{}{arg0, arg1}
	expected = append(expected, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).GeoHash)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeoHash", method, expected...)
}
// GeoPos flattens the fixed and variadic arguments into one list, forwards
// the invocation to the gomock controller, and returns the configured
// *redis.GeoPosCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GeoPos(arg0 context.Context, arg1 string, arg2 ...string) *redis.GeoPosCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	// Box each element individually: a []string cannot be spread into []interface{}.
	for _, elem := range arg2 {
		callArgs = append(callArgs, elem)
	}
	results := m.ctrl.Call(m, "GeoPos", callArgs...)
	cmd, _ := results[0].(*redis.GeoPosCmd)
	return cmd
}
// GeoPos records an expected GeoPos call, flattening the fixed and variadic
// arguments into a single list, and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) GeoPos(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	expected := []interface{}{arg0, arg1}
	expected = append(expected, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).GeoPos)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeoPos", method, expected...)
}
// GeoRadius forwards the invocation to the gomock controller and returns the
// configured *redis.GeoLocationCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GeoRadius(arg0 context.Context, arg1 string, arg2, arg3 float64, arg4 *redis.GeoRadiusQuery) *redis.GeoLocationCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GeoRadius", arg0, arg1, arg2, arg3, arg4)
	cmd, _ := results[0].(*redis.GeoLocationCmd)
	return cmd
}
// GeoRadius records an expected GeoRadius call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GeoRadius(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GeoRadius)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeoRadius", method, arg0, arg1, arg2, arg3, arg4)
}
// GeoRadiusByMember forwards the invocation to the gomock controller and returns
// the configured *redis.GeoLocationCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GeoRadiusByMember(arg0 context.Context, arg1, arg2 string, arg3 *redis.GeoRadiusQuery) *redis.GeoLocationCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GeoRadiusByMember", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.GeoLocationCmd)
	return cmd
}
// GeoRadiusByMember records an expected GeoRadiusByMember call with the given
// arguments and returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GeoRadiusByMember(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GeoRadiusByMember)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeoRadiusByMember", method, arg0, arg1, arg2, arg3)
}
// GeoRadiusByMemberStore forwards the invocation to the gomock controller and
// returns the configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GeoRadiusByMemberStore(arg0 context.Context, arg1, arg2 string, arg3 *redis.GeoRadiusQuery) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GeoRadiusByMemberStore", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// GeoRadiusByMemberStore records an expected GeoRadiusByMemberStore call with the
// given arguments and returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GeoRadiusByMemberStore(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GeoRadiusByMemberStore)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeoRadiusByMemberStore", method, arg0, arg1, arg2, arg3)
}
// GeoRadiusStore forwards the invocation to the gomock controller and returns
// the configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GeoRadiusStore(arg0 context.Context, arg1 string, arg2, arg3 float64, arg4 *redis.GeoRadiusQuery) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GeoRadiusStore", arg0, arg1, arg2, arg3, arg4)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// GeoRadiusStore records an expected GeoRadiusStore call with the given
// arguments and returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GeoRadiusStore(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GeoRadiusStore)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeoRadiusStore", method, arg0, arg1, arg2, arg3, arg4)
}
// GeoSearch forwards the invocation to the gomock controller and returns the
// configured *redis.StringSliceCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GeoSearch(arg0 context.Context, arg1 string, arg2 *redis.GeoSearchQuery) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GeoSearch", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// GeoSearch records an expected GeoSearch call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GeoSearch(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GeoSearch)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeoSearch", method, arg0, arg1, arg2)
}
// GeoSearchLocation forwards the invocation to the gomock controller and returns
// the configured *redis.GeoSearchLocationCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GeoSearchLocation(arg0 context.Context, arg1 string, arg2 *redis.GeoSearchLocationQuery) *redis.GeoSearchLocationCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GeoSearchLocation", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.GeoSearchLocationCmd)
	return cmd
}
// GeoSearchLocation records an expected GeoSearchLocation call with the given
// arguments and returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GeoSearchLocation(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GeoSearchLocation)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeoSearchLocation", method, arg0, arg1, arg2)
}
// GeoSearchStore forwards the invocation to the gomock controller and returns
// the configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GeoSearchStore(arg0 context.Context, arg1, arg2 string, arg3 *redis.GeoSearchStoreQuery) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GeoSearchStore", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// GeoSearchStore records an expected GeoSearchStore call with the given
// arguments and returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GeoSearchStore(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GeoSearchStore)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeoSearchStore", method, arg0, arg1, arg2, arg3)
}
// Get forwards the invocation to the gomock controller and returns the
// configured *redis.StringCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) Get(arg0 context.Context, arg1 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Get", arg0, arg1)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// Get records an expected Get call with the given arguments and returns the
// *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Get)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", method, arg0, arg1)
}
// GetBit forwards the invocation to the gomock controller and returns the
// configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GetBit(arg0 context.Context, arg1 string, arg2 int64) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GetBit", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// GetBit records an expected GetBit call with the given arguments and returns
// the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GetBit(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GetBit)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBit", method, arg0, arg1, arg2)
}
// GetDel forwards the invocation to the gomock controller and returns the
// configured *redis.StringCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GetDel(arg0 context.Context, arg1 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GetDel", arg0, arg1)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// GetDel records an expected GetDel call with the given arguments and returns
// the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GetDel(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GetDel)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDel", method, arg0, arg1)
}
// GetEx forwards the invocation to the gomock controller and returns the
// configured *redis.StringCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GetEx(arg0 context.Context, arg1 string, arg2 time.Duration) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GetEx", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// GetEx records an expected GetEx call with the given arguments and returns
// the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GetEx(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GetEx)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEx", method, arg0, arg1, arg2)
}
// GetRange forwards the invocation to the gomock controller and returns the
// configured *redis.StringCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GetRange(arg0 context.Context, arg1 string, arg2, arg3 int64) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GetRange", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// GetRange records an expected GetRange call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GetRange(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GetRange)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRange", method, arg0, arg1, arg2, arg3)
}
// GetSet forwards the invocation to the gomock controller and returns the
// configured *redis.StringCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) GetSet(arg0 context.Context, arg1 string, arg2 interface{}) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GetSet", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// GetSet records an expected GetSet call with the given arguments and returns
// the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) GetSet(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).GetSet)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSet", method, arg0, arg1, arg2)
}
// HDel flattens the fixed and variadic arguments into one list, forwards the
// invocation to the gomock controller, and returns the configured
// *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HDel(arg0 context.Context, arg1 string, arg2 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	// Box each element individually: a []string cannot be spread into []interface{}.
	for _, elem := range arg2 {
		callArgs = append(callArgs, elem)
	}
	results := m.ctrl.Call(m, "HDel", callArgs...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// HDel records an expected HDel call, flattening the fixed and variadic
// arguments into a single list, and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) HDel(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	expected := []interface{}{arg0, arg1}
	expected = append(expected, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).HDel)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HDel", method, expected...)
}
// HExists forwards the invocation to the gomock controller and returns the
// configured *redis.BoolCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HExists(arg0 context.Context, arg1, arg2 string) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "HExists", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// HExists records an expected HExists call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) HExists(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).HExists)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HExists", method, arg0, arg1, arg2)
}
// HGet forwards the invocation to the gomock controller and returns the
// configured *redis.StringCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HGet(arg0 context.Context, arg1, arg2 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "HGet", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// HGet records an expected HGet call with the given arguments and returns the
// *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) HGet(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).HGet)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HGet", method, arg0, arg1, arg2)
}
// HGetAll forwards the invocation to the gomock controller and returns the
// configured *redis.StringStringMapCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HGetAll(arg0 context.Context, arg1 string) *redis.StringStringMapCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "HGetAll", arg0, arg1)
	cmd, _ := results[0].(*redis.StringStringMapCmd)
	return cmd
}
// HGetAll records an expected HGetAll call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) HGetAll(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).HGetAll)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HGetAll", method, arg0, arg1)
}
// HIncrBy forwards the invocation to the gomock controller and returns the
// configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HIncrBy(arg0 context.Context, arg1, arg2 string, arg3 int64) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "HIncrBy", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// HIncrBy records an expected HIncrBy call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) HIncrBy(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).HIncrBy)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HIncrBy", method, arg0, arg1, arg2, arg3)
}
// HIncrByFloat forwards the invocation to the gomock controller and returns
// the configured *redis.FloatCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HIncrByFloat(arg0 context.Context, arg1, arg2 string, arg3 float64) *redis.FloatCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "HIncrByFloat", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.FloatCmd)
	return cmd
}
// HIncrByFloat records an expected HIncrByFloat call with the given arguments
// and returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) HIncrByFloat(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).HIncrByFloat)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HIncrByFloat", method, arg0, arg1, arg2, arg3)
}
// HKeys forwards the invocation to the gomock controller and returns the
// configured *redis.StringSliceCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HKeys(arg0 context.Context, arg1 string) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "HKeys", arg0, arg1)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// HKeys records an expected HKeys call with the given arguments and returns
// the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) HKeys(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).HKeys)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HKeys", method, arg0, arg1)
}
// HLen forwards the invocation to the gomock controller and returns the
// configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HLen(arg0 context.Context, arg1 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "HLen", arg0, arg1)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// HLen records an expected HLen call with the given arguments and returns the
// *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) HLen(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).HLen)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HLen", method, arg0, arg1)
}
// HMGet flattens the fixed and variadic arguments into one list, forwards the
// invocation to the gomock controller, and returns the configured
// *redis.SliceCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HMGet(arg0 context.Context, arg1 string, arg2 ...string) *redis.SliceCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	// Box each element individually: a []string cannot be spread into []interface{}.
	for _, elem := range arg2 {
		callArgs = append(callArgs, elem)
	}
	results := m.ctrl.Call(m, "HMGet", callArgs...)
	cmd, _ := results[0].(*redis.SliceCmd)
	return cmd
}
// HMGet records an expected HMGet call, flattening the fixed and variadic
// arguments into a single list, and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) HMGet(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	expected := []interface{}{arg0, arg1}
	expected = append(expected, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).HMGet)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HMGet", method, expected...)
}
// HMSet flattens the fixed and variadic arguments into one list, forwards the
// invocation to the gomock controller, and returns the configured
// *redis.BoolCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HMSet(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.BoolCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	results := m.ctrl.Call(m, "HMSet", callArgs...)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// HMSet records an expected HMSet call, flattening the fixed and variadic
// arguments into a single list, and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) HMSet(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	expected := []interface{}{arg0, arg1}
	expected = append(expected, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).HMSet)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HMSet", method, expected...)
}
// HRandField forwards the invocation to the gomock controller and returns the
// configured *redis.StringSliceCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HRandField(arg0 context.Context, arg1 string, arg2 int, arg3 bool) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "HRandField", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// HRandField records an expected HRandField call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) HRandField(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).HRandField)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HRandField", method, arg0, arg1, arg2, arg3)
}
// HScan forwards the invocation to the gomock controller and returns the
// configured *redis.ScanCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HScan(arg0 context.Context, arg1 string, arg2 uint64, arg3 string, arg4 int64) *redis.ScanCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "HScan", arg0, arg1, arg2, arg3, arg4)
	cmd, _ := results[0].(*redis.ScanCmd)
	return cmd
}
// HScan records an expected HScan call with the given arguments and returns
// the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) HScan(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).HScan)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HScan", method, arg0, arg1, arg2, arg3, arg4)
}
// HSet flattens the fixed and variadic arguments into one list, forwards the
// invocation to the gomock controller, and returns the configured
// *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HSet(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	results := m.ctrl.Call(m, "HSet", callArgs...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// HSet records an expected HSet call, flattening the fixed and variadic
// arguments into a single list, and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) HSet(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	expected := []interface{}{arg0, arg1}
	expected = append(expected, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).HSet)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HSet", method, expected...)
}
// HSetNX forwards the invocation to the gomock controller and returns the
// configured *redis.BoolCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HSetNX(arg0 context.Context, arg1, arg2 string, arg3 interface{}) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "HSetNX", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// HSetNX records an expected HSetNX call with the given arguments and returns
// the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) HSetNX(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).HSetNX)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HSetNX", method, arg0, arg1, arg2, arg3)
}
// HVals forwards the invocation to the gomock controller and returns the
// configured *redis.StringSliceCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) HVals(arg0 context.Context, arg1 string) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "HVals", arg0, arg1)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// HVals records an expected HVals call with the given arguments and returns
// the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) HVals(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).HVals)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HVals", method, arg0, arg1)
}
// Incr forwards the invocation to the gomock controller and returns the
// configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) Incr(arg0 context.Context, arg1 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Incr", arg0, arg1)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// Incr records an expected Incr call with the given arguments and returns the
// *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) Incr(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Incr)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Incr", method, arg0, arg1)
}
// IncrBy forwards the invocation to the gomock controller and returns the
// configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) IncrBy(arg0 context.Context, arg1 string, arg2 int64) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "IncrBy", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// IncrBy records an expected IncrBy call with the given arguments and returns
// the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) IncrBy(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).IncrBy)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrBy", method, arg0, arg1, arg2)
}
// IncrByFloat forwards the invocation to the gomock controller and returns the
// configured *redis.FloatCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) IncrByFloat(arg0 context.Context, arg1 string, arg2 float64) *redis.FloatCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "IncrByFloat", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.FloatCmd)
	return cmd
}
// IncrByFloat records an expected IncrByFloat call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) IncrByFloat(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).IncrByFloat)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrByFloat", method, arg0, arg1, arg2)
}
// Info flattens the context and variadic strings into one argument list,
// forwards the invocation to the gomock controller, and returns the configured
// *redis.StringCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) Info(arg0 context.Context, arg1 ...string) *redis.StringCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	// Box each element individually: a []string cannot be spread into []interface{}.
	for _, elem := range arg1 {
		callArgs = append(callArgs, elem)
	}
	results := m.ctrl.Call(m, "Info", callArgs...)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// Info records an expected Info call, flattening the fixed and variadic
// arguments into a single list, and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Info(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	expected := []interface{}{arg0}
	expected = append(expected, arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).Info)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", method, expected...)
}
// Keys forwards the invocation to the gomock controller and returns the
// configured *redis.StringSliceCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) Keys(arg0 context.Context, arg1 string) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Keys", arg0, arg1)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// Keys records an expected Keys call with the given arguments and returns the
// *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) Keys(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Keys)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Keys", method, arg0, arg1)
}
// LIndex forwards the invocation to the gomock controller and returns the
// configured *redis.StringCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) LIndex(arg0 context.Context, arg1 string, arg2 int64) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LIndex", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// LIndex records an expected LIndex call with the given arguments and returns
// the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) LIndex(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LIndex)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LIndex", method, arg0, arg1, arg2)
}
// LInsert forwards the invocation to the gomock controller and returns the
// configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) LInsert(arg0 context.Context, arg1, arg2 string, arg3, arg4 interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LInsert", arg0, arg1, arg2, arg3, arg4)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// LInsert records an expected LInsert call with the given arguments and
// returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) LInsert(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LInsert)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LInsert", method, arg0, arg1, arg2, arg3, arg4)
}
// LInsertAfter forwards the invocation to the gomock controller and returns
// the configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) LInsertAfter(arg0 context.Context, arg1 string, arg2, arg3 interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LInsertAfter", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// LInsertAfter records an expected LInsertAfter call with the given arguments
// and returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) LInsertAfter(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LInsertAfter)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LInsertAfter", method, arg0, arg1, arg2, arg3)
}
// LInsertBefore forwards the invocation to the gomock controller and returns
// the configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) LInsertBefore(arg0 context.Context, arg1 string, arg2, arg3 interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LInsertBefore", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// LInsertBefore records an expected LInsertBefore call with the given arguments
// and returns the *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) LInsertBefore(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LInsertBefore)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LInsertBefore", method, arg0, arg1, arg2, arg3)
}
// LLen forwards the invocation to the gomock controller and returns the
// configured *redis.IntCmd (nil when the stubbed value has another type).
func (m *MockUniversalClient) LLen(arg0 context.Context, arg1 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LLen", arg0, arg1)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// LLen records an expected LLen call with the given arguments and returns the
// *gomock.Call so the expectation can be configured further.
func (mr *MockUniversalClientMockRecorder) LLen(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LLen)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LLen", method, arg0, arg1)
}
// LMove mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StringCmd, otherwise nil.
func (m *MockUniversalClient) LMove(arg0 context.Context, arg1, arg2, arg3, arg4 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LMove", arg0, arg1, arg2, arg3, arg4)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// LMove indicates an expected call of LMove, registering the arguments and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) LMove(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LMove)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LMove", method, arg0, arg1, arg2, arg3, arg4)
}
// LPop mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StringCmd, otherwise nil.
func (m *MockUniversalClient) LPop(arg0 context.Context, arg1 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LPop", arg0, arg1)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// LPop indicates an expected call of LPop, registering the arguments and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) LPop(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LPop)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LPop", method, arg0, arg1)
}
// LPopCount mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StringSliceCmd, otherwise nil.
func (m *MockUniversalClient) LPopCount(arg0 context.Context, arg1 string, arg2 int) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LPopCount", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.StringSliceCmd); ok {
		return cmd
	}
	return nil
}
// LPopCount indicates an expected call of LPopCount, registering the arguments
// and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) LPopCount(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LPopCount)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LPopCount", method, arg0, arg1, arg2)
}
// LPos mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) LPos(arg0 context.Context, arg1, arg2 string, arg3 redis.LPosArgs) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LPos", arg0, arg1, arg2, arg3)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// LPos indicates an expected call of LPos, registering the arguments and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) LPos(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LPos)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LPos", method, arg0, arg1, arg2, arg3)
}
// LPosCount mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.IntSliceCmd, otherwise nil.
func (m *MockUniversalClient) LPosCount(arg0 context.Context, arg1, arg2 string, arg3 int64, arg4 redis.LPosArgs) *redis.IntSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LPosCount", arg0, arg1, arg2, arg3, arg4)
	if cmd, ok := results[0].(*redis.IntSliceCmd); ok {
		return cmd
	}
	return nil
}
// LPosCount indicates an expected call of LPosCount, registering the arguments
// and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) LPosCount(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LPosCount)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LPosCount", method, arg0, arg1, arg2, arg3, arg4)
}
// LPush mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) LPush(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	results := m.ctrl.Call(m, "LPush", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// LPush indicates an expected call of LPush; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) LPush(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).LPush)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LPush", method, callArgs...)
}
// LPushX mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) LPushX(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	results := m.ctrl.Call(m, "LPushX", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// LPushX indicates an expected call of LPushX; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) LPushX(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).LPushX)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LPushX", method, callArgs...)
}
// LRange mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StringSliceCmd, otherwise nil.
func (m *MockUniversalClient) LRange(arg0 context.Context, arg1 string, arg2, arg3 int64) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LRange", arg0, arg1, arg2, arg3)
	if cmd, ok := results[0].(*redis.StringSliceCmd); ok {
		return cmd
	}
	return nil
}
// LRange indicates an expected call of LRange, registering the arguments and
// the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) LRange(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LRange)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LRange", method, arg0, arg1, arg2, arg3)
}
// LRem mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) LRem(arg0 context.Context, arg1 string, arg2 int64, arg3 interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LRem", arg0, arg1, arg2, arg3)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// LRem indicates an expected call of LRem, registering the arguments and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) LRem(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LRem)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LRem", method, arg0, arg1, arg2, arg3)
}
// LSet mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) LSet(arg0 context.Context, arg1 string, arg2 int64, arg3 interface{}) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LSet", arg0, arg1, arg2, arg3)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// LSet indicates an expected call of LSet, registering the arguments and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) LSet(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LSet)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LSet", method, arg0, arg1, arg2, arg3)
}
// LTrim mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) LTrim(arg0 context.Context, arg1 string, arg2, arg3 int64) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LTrim", arg0, arg1, arg2, arg3)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// LTrim indicates an expected call of LTrim, registering the arguments and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) LTrim(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LTrim)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LTrim", method, arg0, arg1, arg2, arg3)
}
// LastSave mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) LastSave(arg0 context.Context) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "LastSave", arg0)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// LastSave indicates an expected call of LastSave, registering the argument and
// the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) LastSave(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).LastSave)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastSave", method, arg0)
}
// MGet mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.SliceCmd, otherwise nil.
func (m *MockUniversalClient) MGet(arg0 context.Context, arg1 ...string) *redis.SliceCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	for i := range arg1 {
		callArgs = append(callArgs, arg1[i])
	}
	results := m.ctrl.Call(m, "MGet", callArgs...)
	if cmd, ok := results[0].(*redis.SliceCmd); ok {
		return cmd
	}
	return nil
}
// MGet indicates an expected call of MGet; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) MGet(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	callArgs = append(callArgs, arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).MGet)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGet", method, callArgs...)
}
// MSet mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) MSet(arg0 context.Context, arg1 ...interface{}) *redis.StatusCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	callArgs = append(callArgs, arg1...)
	results := m.ctrl.Call(m, "MSet", callArgs...)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// MSet indicates an expected call of MSet; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) MSet(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	callArgs = append(callArgs, arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).MSet)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MSet", method, callArgs...)
}
// MSetNX mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.BoolCmd, otherwise nil.
func (m *MockUniversalClient) MSetNX(arg0 context.Context, arg1 ...interface{}) *redis.BoolCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	callArgs = append(callArgs, arg1...)
	results := m.ctrl.Call(m, "MSetNX", callArgs...)
	if cmd, ok := results[0].(*redis.BoolCmd); ok {
		return cmd
	}
	return nil
}
// MSetNX indicates an expected call of MSetNX; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) MSetNX(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	callArgs = append(callArgs, arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).MSetNX)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MSetNX", method, callArgs...)
}
// MemoryUsage mocks base method. Fixed and variadic arguments are flattened into
// one list for the controller; the first result is returned when it is a
// *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) MemoryUsage(arg0 context.Context, arg1 string, arg2 ...int) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	for i := range arg2 {
		callArgs = append(callArgs, arg2[i])
	}
	results := m.ctrl.Call(m, "MemoryUsage", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// MemoryUsage indicates an expected call of MemoryUsage; fixed and variadic
// arguments are flattened into one list before being recorded with the method's
// reflect.Type.
func (mr *MockUniversalClientMockRecorder) MemoryUsage(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).MemoryUsage)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MemoryUsage", method, callArgs...)
}
// Migrate mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) Migrate(arg0 context.Context, arg1, arg2, arg3 string, arg4 int, arg5 time.Duration) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Migrate", arg0, arg1, arg2, arg3, arg4, arg5)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// Migrate indicates an expected call of Migrate, registering the arguments and
// the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) Migrate(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Migrate)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", method, arg0, arg1, arg2, arg3, arg4, arg5)
}
// Move mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.BoolCmd, otherwise nil.
func (m *MockUniversalClient) Move(arg0 context.Context, arg1 string, arg2 int) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Move", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.BoolCmd); ok {
		return cmd
	}
	return nil
}
// Move indicates an expected call of Move, registering the arguments and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) Move(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Move)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Move", method, arg0, arg1, arg2)
}
// ObjectEncoding mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StringCmd, otherwise nil.
func (m *MockUniversalClient) ObjectEncoding(arg0 context.Context, arg1 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ObjectEncoding", arg0, arg1)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// ObjectEncoding indicates an expected call of ObjectEncoding, registering the
// arguments and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) ObjectEncoding(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ObjectEncoding)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ObjectEncoding", method, arg0, arg1)
}
// ObjectIdleTime mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.DurationCmd, otherwise nil.
func (m *MockUniversalClient) ObjectIdleTime(arg0 context.Context, arg1 string) *redis.DurationCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ObjectIdleTime", arg0, arg1)
	if cmd, ok := results[0].(*redis.DurationCmd); ok {
		return cmd
	}
	return nil
}
// ObjectIdleTime indicates an expected call of ObjectIdleTime, registering the
// arguments and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) ObjectIdleTime(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ObjectIdleTime)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ObjectIdleTime", method, arg0, arg1)
}
// ObjectRefCount mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) ObjectRefCount(arg0 context.Context, arg1 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ObjectRefCount", arg0, arg1)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// ObjectRefCount indicates an expected call of ObjectRefCount, registering the
// arguments and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) ObjectRefCount(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ObjectRefCount)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ObjectRefCount", method, arg0, arg1)
}
// PExpire mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.BoolCmd, otherwise nil.
func (m *MockUniversalClient) PExpire(arg0 context.Context, arg1 string, arg2 time.Duration) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "PExpire", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.BoolCmd); ok {
		return cmd
	}
	return nil
}
// PExpire indicates an expected call of PExpire, registering the arguments and
// the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) PExpire(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).PExpire)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PExpire", method, arg0, arg1, arg2)
}
// PExpireAt mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.BoolCmd, otherwise nil.
func (m *MockUniversalClient) PExpireAt(arg0 context.Context, arg1 string, arg2 time.Time) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "PExpireAt", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.BoolCmd); ok {
		return cmd
	}
	return nil
}
// PExpireAt indicates an expected call of PExpireAt, registering the arguments
// and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) PExpireAt(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).PExpireAt)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PExpireAt", method, arg0, arg1, arg2)
}
// PFAdd mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) PFAdd(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	results := m.ctrl.Call(m, "PFAdd", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// PFAdd indicates an expected call of PFAdd; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) PFAdd(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).PFAdd)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PFAdd", method, callArgs...)
}
// PFCount mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) PFCount(arg0 context.Context, arg1 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	for i := range arg1 {
		callArgs = append(callArgs, arg1[i])
	}
	results := m.ctrl.Call(m, "PFCount", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// PFCount indicates an expected call of PFCount; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) PFCount(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	callArgs = append(callArgs, arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).PFCount)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PFCount", method, callArgs...)
}
// PFMerge mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) PFMerge(arg0 context.Context, arg1 string, arg2 ...string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	for i := range arg2 {
		callArgs = append(callArgs, arg2[i])
	}
	results := m.ctrl.Call(m, "PFMerge", callArgs...)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// PFMerge indicates an expected call of PFMerge; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) PFMerge(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).PFMerge)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PFMerge", method, callArgs...)
}
// PSubscribe mocks base method. Fixed and variadic arguments are flattened into
// one list for the controller; the first result is returned when it is a
// *redis.PubSub, otherwise nil.
func (m *MockUniversalClient) PSubscribe(arg0 context.Context, arg1 ...string) *redis.PubSub {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	for i := range arg1 {
		callArgs = append(callArgs, arg1[i])
	}
	results := m.ctrl.Call(m, "PSubscribe", callArgs...)
	if ps, ok := results[0].(*redis.PubSub); ok {
		return ps
	}
	return nil
}
// PSubscribe indicates an expected call of PSubscribe; fixed and variadic
// arguments are flattened into one list before being recorded with the method's
// reflect.Type.
func (mr *MockUniversalClientMockRecorder) PSubscribe(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	callArgs = append(callArgs, arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).PSubscribe)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PSubscribe", method, callArgs...)
}
// PTTL mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.DurationCmd, otherwise nil.
func (m *MockUniversalClient) PTTL(arg0 context.Context, arg1 string) *redis.DurationCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "PTTL", arg0, arg1)
	if cmd, ok := results[0].(*redis.DurationCmd); ok {
		return cmd
	}
	return nil
}
// PTTL indicates an expected call of PTTL, registering the arguments and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) PTTL(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).PTTL)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PTTL", method, arg0, arg1)
}
// Persist mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.BoolCmd, otherwise nil.
func (m *MockUniversalClient) Persist(arg0 context.Context, arg1 string) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Persist", arg0, arg1)
	if cmd, ok := results[0].(*redis.BoolCmd); ok {
		return cmd
	}
	return nil
}
// Persist indicates an expected call of Persist, registering the arguments and
// the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) Persist(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Persist)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Persist", method, arg0, arg1)
}
// Ping mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) Ping(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Ping", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// Ping indicates an expected call of Ping, registering the argument and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) Ping(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Ping)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", method, arg0)
}
// Pipeline mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a redis.Pipeliner, otherwise nil.
func (m *MockUniversalClient) Pipeline() redis.Pipeliner {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Pipeline")
	if p, ok := results[0].(redis.Pipeliner); ok {
		return p
	}
	return nil
}
// Pipeline indicates an expected call of Pipeline, registering the mocked
// method's reflect.Type with the controller (the method takes no arguments).
func (mr *MockUniversalClientMockRecorder) Pipeline() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Pipeline)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pipeline", method)
}
// Pipelined mocks base method by reporting the call to the gomock controller.
// It returns the first two recorded results as []redis.Cmder and error,
// substituting nil for any result of a different type.
func (m *MockUniversalClient) Pipelined(arg0 context.Context, arg1 func(redis.Pipeliner) error) ([]redis.Cmder, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Pipelined", arg0, arg1)
	var (
		cmds []redis.Cmder
		err  error
	)
	if v, ok := results[0].([]redis.Cmder); ok {
		cmds = v
	}
	if v, ok := results[1].(error); ok {
		err = v
	}
	return cmds, err
}
// Pipelined indicates an expected call of Pipelined, registering the arguments
// and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) Pipelined(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Pipelined)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pipelined", method, arg0, arg1)
}
// PoolStats mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.PoolStats, otherwise nil.
func (m *MockUniversalClient) PoolStats() *redis.PoolStats {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "PoolStats")
	if stats, ok := results[0].(*redis.PoolStats); ok {
		return stats
	}
	return nil
}
// PoolStats indicates an expected call of PoolStats, registering the mocked
// method's reflect.Type with the controller (the method takes no arguments).
func (mr *MockUniversalClientMockRecorder) PoolStats() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).PoolStats)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PoolStats", method)
}
// Process mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is an error, otherwise nil.
func (m *MockUniversalClient) Process(arg0 context.Context, arg1 redis.Cmder) error {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Process", arg0, arg1)
	if err, ok := results[0].(error); ok {
		return err
	}
	return nil
}
// Process indicates an expected call of Process, registering the arguments and
// the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) Process(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Process)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Process", method, arg0, arg1)
}
// PubSubChannels mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StringSliceCmd, otherwise nil.
func (m *MockUniversalClient) PubSubChannels(arg0 context.Context, arg1 string) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "PubSubChannels", arg0, arg1)
	if cmd, ok := results[0].(*redis.StringSliceCmd); ok {
		return cmd
	}
	return nil
}
// PubSubChannels indicates an expected call of PubSubChannels, registering the
// arguments and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) PubSubChannels(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).PubSubChannels)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PubSubChannels", method, arg0, arg1)
}
// PubSubNumPat mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) PubSubNumPat(arg0 context.Context) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "PubSubNumPat", arg0)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// PubSubNumPat indicates an expected call of PubSubNumPat, registering the
// argument and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) PubSubNumPat(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).PubSubNumPat)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PubSubNumPat", method, arg0)
}
// PubSubNumSub mocks base method. Fixed and variadic arguments are flattened into
// one list for the controller; the first result is returned when it is a
// *redis.StringIntMapCmd, otherwise nil.
func (m *MockUniversalClient) PubSubNumSub(arg0 context.Context, arg1 ...string) *redis.StringIntMapCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	for i := range arg1 {
		callArgs = append(callArgs, arg1[i])
	}
	results := m.ctrl.Call(m, "PubSubNumSub", callArgs...)
	if cmd, ok := results[0].(*redis.StringIntMapCmd); ok {
		return cmd
	}
	return nil
}
// PubSubNumSub indicates an expected call of PubSubNumSub; fixed and variadic
// arguments are flattened into one list before being recorded with the method's
// reflect.Type.
func (mr *MockUniversalClientMockRecorder) PubSubNumSub(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	callArgs = append(callArgs, arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).PubSubNumSub)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PubSubNumSub", method, callArgs...)
}
// Publish mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) Publish(arg0 context.Context, arg1 string, arg2 interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Publish", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// Publish indicates an expected call of Publish, registering the arguments and
// the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) Publish(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Publish)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Publish", method, arg0, arg1, arg2)
}
// Quit mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) Quit(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Quit", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// Quit indicates an expected call of Quit, registering the argument and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) Quit(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Quit)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Quit", method, arg0)
}
// RPop mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StringCmd, otherwise nil.
func (m *MockUniversalClient) RPop(arg0 context.Context, arg1 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "RPop", arg0, arg1)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// RPop indicates an expected call of RPop, registering the arguments and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) RPop(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).RPop)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RPop", method, arg0, arg1)
}
// RPopCount mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StringSliceCmd, otherwise nil.
func (m *MockUniversalClient) RPopCount(arg0 context.Context, arg1 string, arg2 int) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "RPopCount", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.StringSliceCmd); ok {
		return cmd
	}
	return nil
}
// RPopCount indicates an expected call of RPopCount, registering the arguments
// and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) RPopCount(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).RPopCount)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RPopCount", method, arg0, arg1, arg2)
}
// RPopLPush mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StringCmd, otherwise nil.
func (m *MockUniversalClient) RPopLPush(arg0 context.Context, arg1, arg2 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "RPopLPush", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// RPopLPush indicates an expected call of RPopLPush, registering the arguments
// and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) RPopLPush(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).RPopLPush)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RPopLPush", method, arg0, arg1, arg2)
}
// RPush mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) RPush(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	results := m.ctrl.Call(m, "RPush", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// RPush indicates an expected call of RPush; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) RPush(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).RPush)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RPush", method, callArgs...)
}
// RPushX mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) RPushX(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	results := m.ctrl.Call(m, "RPushX", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// RPushX indicates an expected call of RPushX; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) RPushX(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).RPushX)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RPushX", method, callArgs...)
}
// RandomKey mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StringCmd, otherwise nil.
func (m *MockUniversalClient) RandomKey(arg0 context.Context) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "RandomKey", arg0)
	if cmd, ok := results[0].(*redis.StringCmd); ok {
		return cmd
	}
	return nil
}
// RandomKey indicates an expected call of RandomKey, registering the argument
// and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) RandomKey(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).RandomKey)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RandomKey", method, arg0)
}
// ReadOnly mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) ReadOnly(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ReadOnly", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ReadOnly indicates an expected call of ReadOnly, registering the argument and
// the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) ReadOnly(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ReadOnly)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadOnly", method, arg0)
}
// ReadWrite mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) ReadWrite(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ReadWrite", arg0)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// ReadWrite indicates an expected call of ReadWrite, registering the argument
// and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) ReadWrite(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ReadWrite)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadWrite", method, arg0)
}
// Rename mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) Rename(arg0 context.Context, arg1, arg2 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Rename", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// Rename indicates an expected call of Rename, registering the arguments and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) Rename(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Rename)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rename", method, arg0, arg1, arg2)
}
// RenameNX mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.BoolCmd, otherwise nil.
func (m *MockUniversalClient) RenameNX(arg0 context.Context, arg1, arg2 string) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "RenameNX", arg0, arg1, arg2)
	if cmd, ok := results[0].(*redis.BoolCmd); ok {
		return cmd
	}
	return nil
}
// RenameNX indicates an expected call of RenameNX, registering the arguments and
// the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) RenameNX(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).RenameNX)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenameNX", method, arg0, arg1, arg2)
}
// Restore mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) Restore(arg0 context.Context, arg1 string, arg2 time.Duration, arg3 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Restore", arg0, arg1, arg2, arg3)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// Restore indicates an expected call of Restore, registering the arguments and
// the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) Restore(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Restore)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Restore", method, arg0, arg1, arg2, arg3)
}
// RestoreReplace mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.StatusCmd, otherwise nil.
func (m *MockUniversalClient) RestoreReplace(arg0 context.Context, arg1 string, arg2 time.Duration, arg3 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "RestoreReplace", arg0, arg1, arg2, arg3)
	if cmd, ok := results[0].(*redis.StatusCmd); ok {
		return cmd
	}
	return nil
}
// RestoreReplace indicates an expected call of RestoreReplace, registering the
// arguments and the mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) RestoreReplace(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).RestoreReplace)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoreReplace", method, arg0, arg1, arg2, arg3)
}
// SAdd mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) SAdd(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	results := m.ctrl.Call(m, "SAdd", callArgs...)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// SAdd indicates an expected call of SAdd; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) SAdd(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 2+len(arg2))
	callArgs = append(callArgs, arg0, arg1)
	callArgs = append(callArgs, arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).SAdd)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SAdd", method, callArgs...)
}
// SCard mocks base method by reporting the call to the gomock controller.
// It returns the first recorded result when it is a *redis.IntCmd, otherwise nil.
func (m *MockUniversalClient) SCard(arg0 context.Context, arg1 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SCard", arg0, arg1)
	if cmd, ok := results[0].(*redis.IntCmd); ok {
		return cmd
	}
	return nil
}
// SCard indicates an expected call of SCard, registering the arguments and the
// mocked method's reflect.Type with the controller.
func (mr *MockUniversalClientMockRecorder) SCard(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SCard)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SCard", method, arg0, arg1)
}
// SDiff mocks base method. Fixed and variadic arguments are flattened into one
// list for the controller; the first result is returned when it is a
// *redis.StringSliceCmd, otherwise nil.
func (m *MockUniversalClient) SDiff(arg0 context.Context, arg1 ...string) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	for i := range arg1 {
		callArgs = append(callArgs, arg1[i])
	}
	results := m.ctrl.Call(m, "SDiff", callArgs...)
	if cmd, ok := results[0].(*redis.StringSliceCmd); ok {
		return cmd
	}
	return nil
}
// SDiff indicates an expected call of SDiff; fixed and variadic arguments are
// flattened into one list before being recorded with the method's reflect.Type.
func (mr *MockUniversalClientMockRecorder) SDiff(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	callArgs := make([]interface{}, 0, 1+len(arg1))
	callArgs = append(callArgs, arg0)
	callArgs = append(callArgs, arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).SDiff)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SDiff", method, callArgs...)
}
// SDiffStore passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) SDiffStore(arg0 context.Context, arg1 string, arg2 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for _, v := range arg2 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "SDiffStore", args...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// SDiffStore registers an expected call of SDiffStore with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SDiffStore(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(append(args, arg0, arg1), arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).SDiffStore)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SDiffStore", method, args...)
}
// SInter passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StringSliceCmd.
func (m *MockUniversalClient) SInter(arg0 context.Context, arg1 ...string) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(args, arg0)
	for _, v := range arg1 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "SInter", args...)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// SInter registers an expected call of SInter with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SInter(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(append(args, arg0), arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).SInter)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SInter", method, args...)
}
// SInterStore passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) SInterStore(arg0 context.Context, arg1 string, arg2 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for _, v := range arg2 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "SInterStore", args...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// SInterStore registers an expected call of SInterStore with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SInterStore(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(append(args, arg0, arg1), arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).SInterStore)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SInterStore", method, args...)
}
// SIsMember passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.BoolCmd.
func (m *MockUniversalClient) SIsMember(arg0 context.Context, arg1 string, arg2 interface{}) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SIsMember", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// SIsMember registers an expected call of SIsMember with the given arguments
// and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SIsMember(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SIsMember)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SIsMember", method, arg0, arg1, arg2)
}
// SMIsMember passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.BoolSliceCmd.
func (m *MockUniversalClient) SMIsMember(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.BoolSliceCmd {
	m.ctrl.T.Helper()
	// Flatten the variadic members so each one reaches the controller individually.
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	args = append(args, arg2...)
	results := m.ctrl.Call(m, "SMIsMember", args...)
	cmd, _ := results[0].(*redis.BoolSliceCmd)
	return cmd
}
// SMIsMember registers an expected call of SMIsMember with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SMIsMember(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(append(args, arg0, arg1), arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).SMIsMember)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMIsMember", method, args...)
}
// SMembers passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.StringSliceCmd.
func (m *MockUniversalClient) SMembers(arg0 context.Context, arg1 string) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SMembers", arg0, arg1)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// SMembers registers an expected call of SMembers with the given arguments
// and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SMembers(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SMembers)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMembers", method, arg0, arg1)
}
// SMembersMap passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.StringStructMapCmd.
func (m *MockUniversalClient) SMembersMap(arg0 context.Context, arg1 string) *redis.StringStructMapCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SMembersMap", arg0, arg1)
	cmd, _ := results[0].(*redis.StringStructMapCmd)
	return cmd
}
// SMembersMap registers an expected call of SMembersMap with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SMembersMap(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SMembersMap)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMembersMap", method, arg0, arg1)
}
// SMove passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.BoolCmd.
func (m *MockUniversalClient) SMove(arg0 context.Context, arg1, arg2 string, arg3 interface{}) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SMove", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// SMove registers an expected call of SMove with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SMove(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SMove)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMove", method, arg0, arg1, arg2, arg3)
}
// SPop passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StringCmd.
func (m *MockUniversalClient) SPop(arg0 context.Context, arg1 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SPop", arg0, arg1)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// SPop registers an expected call of SPop with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SPop(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SPop)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SPop", method, arg0, arg1)
}
// SPopN passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StringSliceCmd.
func (m *MockUniversalClient) SPopN(arg0 context.Context, arg1 string, arg2 int64) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SPopN", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// SPopN registers an expected call of SPopN with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SPopN(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SPopN)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SPopN", method, arg0, arg1, arg2)
}
// SRandMember passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.StringCmd.
func (m *MockUniversalClient) SRandMember(arg0 context.Context, arg1 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SRandMember", arg0, arg1)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// SRandMember registers an expected call of SRandMember with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SRandMember(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SRandMember)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SRandMember", method, arg0, arg1)
}
// SRandMemberN passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.StringSliceCmd.
func (m *MockUniversalClient) SRandMemberN(arg0 context.Context, arg1 string, arg2 int64) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SRandMemberN", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// SRandMemberN registers an expected call of SRandMemberN with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SRandMemberN(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SRandMemberN)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SRandMemberN", method, arg0, arg1, arg2)
}
// SRem passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) SRem(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.IntCmd {
	m.ctrl.T.Helper()
	// Flatten the variadic members so each one reaches the controller individually.
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	args = append(args, arg2...)
	results := m.ctrl.Call(m, "SRem", args...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// SRem registers an expected call of SRem with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SRem(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(append(args, arg0, arg1), arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).SRem)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SRem", method, args...)
}
// SScan passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.ScanCmd.
func (m *MockUniversalClient) SScan(arg0 context.Context, arg1 string, arg2 uint64, arg3 string, arg4 int64) *redis.ScanCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SScan", arg0, arg1, arg2, arg3, arg4)
	cmd, _ := results[0].(*redis.ScanCmd)
	return cmd
}
// SScan registers an expected call of SScan with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SScan(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SScan)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SScan", method, arg0, arg1, arg2, arg3, arg4)
}
// SUnion passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StringSliceCmd.
func (m *MockUniversalClient) SUnion(arg0 context.Context, arg1 ...string) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(args, arg0)
	for _, v := range arg1 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "SUnion", args...)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// SUnion registers an expected call of SUnion with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SUnion(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(append(args, arg0), arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).SUnion)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SUnion", method, args...)
}
// SUnionStore passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) SUnionStore(arg0 context.Context, arg1 string, arg2 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for _, v := range arg2 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "SUnionStore", args...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// SUnionStore registers an expected call of SUnionStore with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SUnionStore(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(append(args, arg0, arg1), arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).SUnionStore)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SUnionStore", method, args...)
}
// Save passes its argument to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StatusCmd.
func (m *MockUniversalClient) Save(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Save", arg0)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// Save registers an expected call of Save with the given argument and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Save(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Save)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", method, arg0)
}
// Scan passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.ScanCmd.
func (m *MockUniversalClient) Scan(arg0 context.Context, arg1 uint64, arg2 string, arg3 int64) *redis.ScanCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Scan", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.ScanCmd)
	return cmd
}
// Scan registers an expected call of Scan with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Scan(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Scan)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", method, arg0, arg1, arg2, arg3)
}
// ScanType passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.ScanCmd.
func (m *MockUniversalClient) ScanType(arg0 context.Context, arg1 uint64, arg2 string, arg3 int64, arg4 string) *redis.ScanCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ScanType", arg0, arg1, arg2, arg3, arg4)
	cmd, _ := results[0].(*redis.ScanCmd)
	return cmd
}
// ScanType registers an expected call of ScanType with the given arguments
// and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ScanType(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ScanType)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScanType", method, arg0, arg1, arg2, arg3, arg4)
}
// ScriptExists passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.BoolSliceCmd.
func (m *MockUniversalClient) ScriptExists(arg0 context.Context, arg1 ...string) *redis.BoolSliceCmd {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(args, arg0)
	for _, v := range arg1 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "ScriptExists", args...)
	cmd, _ := results[0].(*redis.BoolSliceCmd)
	return cmd
}
// ScriptExists registers an expected call of ScriptExists with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ScriptExists(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(append(args, arg0), arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).ScriptExists)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScriptExists", method, args...)
}
// ScriptFlush passes its argument to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.StatusCmd.
func (m *MockUniversalClient) ScriptFlush(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ScriptFlush", arg0)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// ScriptFlush registers an expected call of ScriptFlush with the given
// argument and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ScriptFlush(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ScriptFlush)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScriptFlush", method, arg0)
}
// ScriptKill passes its argument to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.StatusCmd.
func (m *MockUniversalClient) ScriptKill(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ScriptKill", arg0)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// ScriptKill registers an expected call of ScriptKill with the given argument
// and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ScriptKill(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ScriptKill)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScriptKill", method, arg0)
}
// ScriptLoad passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.StringCmd.
func (m *MockUniversalClient) ScriptLoad(arg0 context.Context, arg1 string) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ScriptLoad", arg0, arg1)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// ScriptLoad registers an expected call of ScriptLoad with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ScriptLoad(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ScriptLoad)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScriptLoad", method, arg0, arg1)
}
// Set passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StatusCmd.
func (m *MockUniversalClient) Set(arg0 context.Context, arg1 string, arg2 interface{}, arg3 time.Duration) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Set", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// Set registers an expected call of Set with the given arguments and returns
// the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Set(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Set)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", method, arg0, arg1, arg2, arg3)
}
// SetArgs passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StatusCmd.
func (m *MockUniversalClient) SetArgs(arg0 context.Context, arg1 string, arg2 interface{}, arg3 redis.SetArgs) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SetArgs", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// SetArgs registers an expected call of SetArgs with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SetArgs(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SetArgs)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetArgs", method, arg0, arg1, arg2, arg3)
}
// SetBit passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) SetBit(arg0 context.Context, arg1 string, arg2 int64, arg3 int) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SetBit", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// SetBit registers an expected call of SetBit with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SetBit(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SetBit)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBit", method, arg0, arg1, arg2, arg3)
}
// SetEX passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StatusCmd.
func (m *MockUniversalClient) SetEX(arg0 context.Context, arg1 string, arg2 interface{}, arg3 time.Duration) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SetEX", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// SetEX registers an expected call of SetEX with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SetEX(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SetEX)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetEX", method, arg0, arg1, arg2, arg3)
}
// SetNX passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.BoolCmd.
func (m *MockUniversalClient) SetNX(arg0 context.Context, arg1 string, arg2 interface{}, arg3 time.Duration) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SetNX", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// SetNX registers an expected call of SetNX with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SetNX(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SetNX)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNX", method, arg0, arg1, arg2, arg3)
}
// SetRange passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) SetRange(arg0 context.Context, arg1 string, arg2 int64, arg3 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SetRange", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// SetRange registers an expected call of SetRange with the given arguments
// and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SetRange(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SetRange)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRange", method, arg0, arg1, arg2, arg3)
}
// SetXX passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.BoolCmd.
func (m *MockUniversalClient) SetXX(arg0 context.Context, arg1 string, arg2 interface{}, arg3 time.Duration) *redis.BoolCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SetXX", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.BoolCmd)
	return cmd
}
// SetXX registers an expected call of SetXX with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SetXX(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SetXX)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetXX", method, arg0, arg1, arg2, arg3)
}
// Shutdown passes its argument to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StatusCmd.
func (m *MockUniversalClient) Shutdown(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Shutdown", arg0)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// Shutdown registers an expected call of Shutdown with the given argument and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Shutdown(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Shutdown)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", method, arg0)
}
// ShutdownNoSave passes its argument to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.StatusCmd.
func (m *MockUniversalClient) ShutdownNoSave(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ShutdownNoSave", arg0)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// ShutdownNoSave registers an expected call of ShutdownNoSave with the given
// argument and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ShutdownNoSave(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ShutdownNoSave)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShutdownNoSave", method, arg0)
}
// ShutdownSave passes its argument to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.StatusCmd.
func (m *MockUniversalClient) ShutdownSave(arg0 context.Context) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "ShutdownSave", arg0)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// ShutdownSave registers an expected call of ShutdownSave with the given
// argument and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ShutdownSave(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ShutdownSave)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShutdownSave", method, arg0)
}
// SlaveOf passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StatusCmd.
func (m *MockUniversalClient) SlaveOf(arg0 context.Context, arg1, arg2 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SlaveOf", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// SlaveOf registers an expected call of SlaveOf with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SlaveOf(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SlaveOf)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SlaveOf", method, arg0, arg1, arg2)
}
// Sort passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StringSliceCmd.
func (m *MockUniversalClient) Sort(arg0 context.Context, arg1 string, arg2 *redis.Sort) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Sort", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// Sort registers an expected call of Sort with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Sort(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Sort)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sort", method, arg0, arg1, arg2)
}
// SortInterfaces passes its arguments to the gomock controller and returns
// the first configured result, or nil if that result is not a *redis.SliceCmd.
func (m *MockUniversalClient) SortInterfaces(arg0 context.Context, arg1 string, arg2 *redis.Sort) *redis.SliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SortInterfaces", arg0, arg1, arg2)
	cmd, _ := results[0].(*redis.SliceCmd)
	return cmd
}
// SortInterfaces registers an expected call of SortInterfaces with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SortInterfaces(arg0, arg1, arg2 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SortInterfaces)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SortInterfaces", method, arg0, arg1, arg2)
}
// SortStore passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) SortStore(arg0 context.Context, arg1, arg2 string, arg3 *redis.Sort) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "SortStore", arg0, arg1, arg2, arg3)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// SortStore registers an expected call of SortStore with the given arguments
// and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) SortStore(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).SortStore)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SortStore", method, arg0, arg1, arg2, arg3)
}
// StrLen passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) StrLen(arg0 context.Context, arg1 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "StrLen", arg0, arg1)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// StrLen registers an expected call of StrLen with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) StrLen(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).StrLen)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StrLen", method, arg0, arg1)
}
// Subscribe passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.PubSub.
func (m *MockUniversalClient) Subscribe(arg0 context.Context, arg1 ...string) *redis.PubSub {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(args, arg0)
	for _, v := range arg1 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "Subscribe", args...)
	pubsub, _ := results[0].(*redis.PubSub)
	return pubsub
}
// Subscribe registers an expected call of Subscribe with the given arguments
// and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Subscribe(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(append(args, arg0), arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).Subscribe)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", method, args...)
}
// TTL passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.DurationCmd.
func (m *MockUniversalClient) TTL(arg0 context.Context, arg1 string) *redis.DurationCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "TTL", arg0, arg1)
	cmd, _ := results[0].(*redis.DurationCmd)
	return cmd
}
// TTL registers an expected call of TTL with the given arguments and returns
// the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) TTL(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).TTL)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TTL", method, arg0, arg1)
}
// Time passes its argument to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.TimeCmd.
func (m *MockUniversalClient) Time(arg0 context.Context) *redis.TimeCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Time", arg0)
	cmd, _ := results[0].(*redis.TimeCmd)
	return cmd
}
// Time registers an expected call of Time with the given argument and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Time(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Time)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Time", method, arg0)
}
// Touch passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) Touch(arg0 context.Context, arg1 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(args, arg0)
	for _, v := range arg1 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "Touch", args...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// Touch registers an expected call of Touch with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Touch(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(append(args, arg0), arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).Touch)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Touch", method, args...)
}
// TxPipeline notifies the gomock controller of the call and returns the first
// configured result, or nil if that result is not a redis.Pipeliner.
func (m *MockUniversalClient) TxPipeline() redis.Pipeliner {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "TxPipeline")
	pipe, _ := results[0].(redis.Pipeliner)
	return pipe
}
// TxPipeline registers an expected call of TxPipeline and returns the
// resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) TxPipeline() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).TxPipeline)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TxPipeline", method)
}
// TxPipelined passes its arguments to the gomock controller and returns the
// first two configured results as ([]redis.Cmder, error); a result of any
// other type is replaced by nil.
func (m *MockUniversalClient) TxPipelined(arg0 context.Context, arg1 func(redis.Pipeliner) error) ([]redis.Cmder, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "TxPipelined", arg0, arg1)
	cmds, _ := results[0].([]redis.Cmder)
	err, _ := results[1].(error)
	return cmds, err
}
// TxPipelined registers an expected call of TxPipelined with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) TxPipelined(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).TxPipelined)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TxPipelined", method, arg0, arg1)
}
// Type passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StatusCmd.
func (m *MockUniversalClient) Type(arg0 context.Context, arg1 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "Type", arg0, arg1)
	cmd, _ := results[0].(*redis.StatusCmd)
	return cmd
}
// Type registers an expected call of Type with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Type(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).Type)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Type", method, arg0, arg1)
}
// Unlink passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) Unlink(arg0 context.Context, arg1 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(args, arg0)
	for _, v := range arg1 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "Unlink", args...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// Unlink registers an expected call of Unlink with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Unlink(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(append(args, arg0), arg1...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).Unlink)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlink", method, args...)
}
// Watch passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not an error.
func (m *MockUniversalClient) Watch(arg0 context.Context, arg1 func(*redis.Tx) error, arg2 ...string) error {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for _, v := range arg2 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "Watch", args...)
	err, _ := results[0].(error)
	return err
}
// Watch registers an expected call of Watch with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) Watch(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(append(args, arg0, arg1), arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).Watch)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", method, args...)
}
// XAck passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) XAck(arg0 context.Context, arg1, arg2 string, arg3 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 3+len(arg3))
	args = append(args, arg0, arg1, arg2)
	for _, v := range arg3 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "XAck", args...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// XAck registers an expected call of XAck with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XAck(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 3+len(arg3))
	args = append(append(args, arg0, arg1, arg2), arg3...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).XAck)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "XAck", method, args...)
}
// XAdd passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.StringCmd.
func (m *MockUniversalClient) XAdd(arg0 context.Context, arg1 *redis.XAddArgs) *redis.StringCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "XAdd", arg0, arg1)
	cmd, _ := results[0].(*redis.StringCmd)
	return cmd
}
// XAdd registers an expected call of XAdd with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XAdd(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XAdd)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "XAdd", method, arg0, arg1)
}
// XAutoClaim passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.XAutoClaimCmd.
func (m *MockUniversalClient) XAutoClaim(arg0 context.Context, arg1 *redis.XAutoClaimArgs) *redis.XAutoClaimCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "XAutoClaim", arg0, arg1)
	cmd, _ := results[0].(*redis.XAutoClaimCmd)
	return cmd
}
// XAutoClaim registers an expected call of XAutoClaim with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XAutoClaim(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XAutoClaim)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "XAutoClaim", method, arg0, arg1)
}
// XAutoClaimJustID passes its arguments to the gomock controller and returns
// the first configured result, or nil if that result is not a
// *redis.XAutoClaimJustIDCmd.
func (m *MockUniversalClient) XAutoClaimJustID(arg0 context.Context, arg1 *redis.XAutoClaimArgs) *redis.XAutoClaimJustIDCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "XAutoClaimJustID", arg0, arg1)
	cmd, _ := results[0].(*redis.XAutoClaimJustIDCmd)
	return cmd
}
// XAutoClaimJustID registers an expected call of XAutoClaimJustID with the
// given arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XAutoClaimJustID(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XAutoClaimJustID)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "XAutoClaimJustID", method, arg0, arg1)
}
// XClaim passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.XMessageSliceCmd.
func (m *MockUniversalClient) XClaim(arg0 context.Context, arg1 *redis.XClaimArgs) *redis.XMessageSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "XClaim", arg0, arg1)
	cmd, _ := results[0].(*redis.XMessageSliceCmd)
	return cmd
}
// XClaim registers an expected call of XClaim with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XClaim(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XClaim)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "XClaim", method, arg0, arg1)
}
// XClaimJustID passes its arguments to the gomock controller and returns the
// first configured result, or nil if that result is not a *redis.StringSliceCmd.
func (m *MockUniversalClient) XClaimJustID(arg0 context.Context, arg1 *redis.XClaimArgs) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "XClaimJustID", arg0, arg1)
	cmd, _ := results[0].(*redis.StringSliceCmd)
	return cmd
}
// XClaimJustID registers an expected call of XClaimJustID with the given
// arguments and returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XClaimJustID(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XClaimJustID)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "XClaimJustID", method, arg0, arg1)
}
// XDel passes its arguments to the gomock controller and returns the first
// configured result, or nil if that result is not a *redis.IntCmd.
func (m *MockUniversalClient) XDel(arg0 context.Context, arg1 string, arg2 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	// Spread the variadic strings so each one reaches the controller individually.
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for _, v := range arg2 {
		args = append(args, v)
	}
	results := m.ctrl.Call(m, "XDel", args...)
	cmd, _ := results[0].(*redis.IntCmd)
	return cmd
}
// XDel registers an expected call of XDel with the given arguments and
// returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XDel(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(append(args, arg0, arg1), arg2...)
	method := reflect.TypeOf((*MockUniversalClient)(nil).XDel)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "XDel", method, args...)
}
// XGroupCreate mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.StatusCmd.
func (m *MockUniversalClient) XGroupCreate(arg0 context.Context, arg1, arg2, arg3 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XGroupCreate", arg0, arg1, arg2, arg3)[0].(*redis.StatusCmd)
	return cmd
}
// XGroupCreate records an expected call of XGroupCreate and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XGroupCreate(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XGroupCreate)
	return m.ctrl.RecordCallWithMethodType(m, "XGroupCreate", method, arg0, arg1, arg2, arg3)
}
// XGroupCreateConsumer mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) XGroupCreateConsumer(arg0 context.Context, arg1, arg2, arg3 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XGroupCreateConsumer", arg0, arg1, arg2, arg3)[0].(*redis.IntCmd)
	return cmd
}
// XGroupCreateConsumer records an expected call of XGroupCreateConsumer and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XGroupCreateConsumer(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XGroupCreateConsumer)
	return m.ctrl.RecordCallWithMethodType(m, "XGroupCreateConsumer", method, arg0, arg1, arg2, arg3)
}
// XGroupCreateMkStream mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.StatusCmd.
func (m *MockUniversalClient) XGroupCreateMkStream(arg0 context.Context, arg1, arg2, arg3 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XGroupCreateMkStream", arg0, arg1, arg2, arg3)[0].(*redis.StatusCmd)
	return cmd
}
// XGroupCreateMkStream records an expected call of XGroupCreateMkStream and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XGroupCreateMkStream(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XGroupCreateMkStream)
	return m.ctrl.RecordCallWithMethodType(m, "XGroupCreateMkStream", method, arg0, arg1, arg2, arg3)
}
// XGroupDelConsumer mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) XGroupDelConsumer(arg0 context.Context, arg1, arg2, arg3 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XGroupDelConsumer", arg0, arg1, arg2, arg3)[0].(*redis.IntCmd)
	return cmd
}
// XGroupDelConsumer records an expected call of XGroupDelConsumer and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XGroupDelConsumer(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XGroupDelConsumer)
	return m.ctrl.RecordCallWithMethodType(m, "XGroupDelConsumer", method, arg0, arg1, arg2, arg3)
}
// XGroupDestroy mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) XGroupDestroy(arg0 context.Context, arg1, arg2 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XGroupDestroy", arg0, arg1, arg2)[0].(*redis.IntCmd)
	return cmd
}
// XGroupDestroy records an expected call of XGroupDestroy and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XGroupDestroy(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XGroupDestroy)
	return m.ctrl.RecordCallWithMethodType(m, "XGroupDestroy", method, arg0, arg1, arg2)
}
// XGroupSetID mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.StatusCmd.
func (m *MockUniversalClient) XGroupSetID(arg0 context.Context, arg1, arg2, arg3 string) *redis.StatusCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XGroupSetID", arg0, arg1, arg2, arg3)[0].(*redis.StatusCmd)
	return cmd
}
// XGroupSetID records an expected call of XGroupSetID and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XGroupSetID(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XGroupSetID)
	return m.ctrl.RecordCallWithMethodType(m, "XGroupSetID", method, arg0, arg1, arg2, arg3)
}
// XInfoConsumers mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XInfoConsumersCmd.
func (m *MockUniversalClient) XInfoConsumers(arg0 context.Context, arg1, arg2 string) *redis.XInfoConsumersCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XInfoConsumers", arg0, arg1, arg2)[0].(*redis.XInfoConsumersCmd)
	return cmd
}
// XInfoConsumers records an expected call of XInfoConsumers and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XInfoConsumers(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XInfoConsumers)
	return m.ctrl.RecordCallWithMethodType(m, "XInfoConsumers", method, arg0, arg1, arg2)
}
// XInfoGroups mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XInfoGroupsCmd.
func (m *MockUniversalClient) XInfoGroups(arg0 context.Context, arg1 string) *redis.XInfoGroupsCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XInfoGroups", arg0, arg1)[0].(*redis.XInfoGroupsCmd)
	return cmd
}
// XInfoGroups records an expected call of XInfoGroups and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XInfoGroups(arg0, arg1 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XInfoGroups)
	return m.ctrl.RecordCallWithMethodType(m, "XInfoGroups", method, arg0, arg1)
}
// XInfoStream mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XInfoStreamCmd.
func (m *MockUniversalClient) XInfoStream(arg0 context.Context, arg1 string) *redis.XInfoStreamCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XInfoStream", arg0, arg1)[0].(*redis.XInfoStreamCmd)
	return cmd
}
// XInfoStream records an expected call of XInfoStream and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XInfoStream(arg0, arg1 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XInfoStream)
	return m.ctrl.RecordCallWithMethodType(m, "XInfoStream", method, arg0, arg1)
}
// XInfoStreamFull mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XInfoStreamFullCmd.
func (m *MockUniversalClient) XInfoStreamFull(arg0 context.Context, arg1 string, arg2 int) *redis.XInfoStreamFullCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XInfoStreamFull", arg0, arg1, arg2)[0].(*redis.XInfoStreamFullCmd)
	return cmd
}
// XInfoStreamFull records an expected call of XInfoStreamFull and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XInfoStreamFull(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XInfoStreamFull)
	return m.ctrl.RecordCallWithMethodType(m, "XInfoStreamFull", method, arg0, arg1, arg2)
}
// XLen mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) XLen(arg0 context.Context, arg1 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XLen", arg0, arg1)[0].(*redis.IntCmd)
	return cmd
}
// XLen records an expected call of XLen and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XLen(arg0, arg1 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XLen)
	return m.ctrl.RecordCallWithMethodType(m, "XLen", method, arg0, arg1)
}
// XPending mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XPendingCmd.
func (m *MockUniversalClient) XPending(arg0 context.Context, arg1, arg2 string) *redis.XPendingCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XPending", arg0, arg1, arg2)[0].(*redis.XPendingCmd)
	return cmd
}
// XPending records an expected call of XPending and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XPending(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XPending)
	return m.ctrl.RecordCallWithMethodType(m, "XPending", method, arg0, arg1, arg2)
}
// XPendingExt mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XPendingExtCmd.
func (m *MockUniversalClient) XPendingExt(arg0 context.Context, arg1 *redis.XPendingExtArgs) *redis.XPendingExtCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XPendingExt", arg0, arg1)[0].(*redis.XPendingExtCmd)
	return cmd
}
// XPendingExt records an expected call of XPendingExt and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XPendingExt(arg0, arg1 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XPendingExt)
	return m.ctrl.RecordCallWithMethodType(m, "XPendingExt", method, arg0, arg1)
}
// XRange mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XMessageSliceCmd.
func (m *MockUniversalClient) XRange(arg0 context.Context, arg1, arg2, arg3 string) *redis.XMessageSliceCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XRange", arg0, arg1, arg2, arg3)[0].(*redis.XMessageSliceCmd)
	return cmd
}
// XRange records an expected call of XRange and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XRange(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XRange)
	return m.ctrl.RecordCallWithMethodType(m, "XRange", method, arg0, arg1, arg2, arg3)
}
// XRangeN mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XMessageSliceCmd.
func (m *MockUniversalClient) XRangeN(arg0 context.Context, arg1, arg2, arg3 string, arg4 int64) *redis.XMessageSliceCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XRangeN", arg0, arg1, arg2, arg3, arg4)[0].(*redis.XMessageSliceCmd)
	return cmd
}
// XRangeN records an expected call of XRangeN and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XRangeN(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XRangeN)
	return m.ctrl.RecordCallWithMethodType(m, "XRangeN", method, arg0, arg1, arg2, arg3, arg4)
}
// XRead mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XStreamSliceCmd.
func (m *MockUniversalClient) XRead(arg0 context.Context, arg1 *redis.XReadArgs) *redis.XStreamSliceCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XRead", arg0, arg1)[0].(*redis.XStreamSliceCmd)
	return cmd
}
// XRead records an expected call of XRead and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XRead(arg0, arg1 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XRead)
	return m.ctrl.RecordCallWithMethodType(m, "XRead", method, arg0, arg1)
}
// XReadGroup mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XStreamSliceCmd.
func (m *MockUniversalClient) XReadGroup(arg0 context.Context, arg1 *redis.XReadGroupArgs) *redis.XStreamSliceCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XReadGroup", arg0, arg1)[0].(*redis.XStreamSliceCmd)
	return cmd
}
// XReadGroup records an expected call of XReadGroup and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XReadGroup(arg0, arg1 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XReadGroup)
	return m.ctrl.RecordCallWithMethodType(m, "XReadGroup", method, arg0, arg1)
}
// XReadStreams mocks base method by delegating to the gomock controller. The
// variadic arguments are flattened after the context. It returns nil unless
// the first stubbed result is a *redis.XStreamSliceCmd.
func (m *MockUniversalClient) XReadStreams(arg0 context.Context, arg1 ...string) *redis.XStreamSliceCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(args, arg0)
	for i := range arg1 {
		args = append(args, arg1[i])
	}
	cmd, _ := m.ctrl.Call(m, "XReadStreams", args...)[0].(*redis.XStreamSliceCmd)
	return cmd
}
// XReadStreams records an expected call of XReadStreams. The variadic
// expectations follow the fixed one. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XReadStreams(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 1+len(arg1))
	expected = append(append(expected, arg0), arg1...)
	return m.ctrl.RecordCallWithMethodType(m, "XReadStreams", reflect.TypeOf((*MockUniversalClient)(nil).XReadStreams), expected...)
}
// XRevRange mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XMessageSliceCmd.
func (m *MockUniversalClient) XRevRange(arg0 context.Context, arg1, arg2, arg3 string) *redis.XMessageSliceCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XRevRange", arg0, arg1, arg2, arg3)[0].(*redis.XMessageSliceCmd)
	return cmd
}
// XRevRange records an expected call of XRevRange and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XRevRange(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XRevRange)
	return m.ctrl.RecordCallWithMethodType(m, "XRevRange", method, arg0, arg1, arg2, arg3)
}
// XRevRangeN mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.XMessageSliceCmd.
func (m *MockUniversalClient) XRevRangeN(arg0 context.Context, arg1, arg2, arg3 string, arg4 int64) *redis.XMessageSliceCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XRevRangeN", arg0, arg1, arg2, arg3, arg4)[0].(*redis.XMessageSliceCmd)
	return cmd
}
// XRevRangeN records an expected call of XRevRangeN and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XRevRangeN(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XRevRangeN)
	return m.ctrl.RecordCallWithMethodType(m, "XRevRangeN", method, arg0, arg1, arg2, arg3, arg4)
}
// XTrim mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) XTrim(arg0 context.Context, arg1 string, arg2 int64) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XTrim", arg0, arg1, arg2)[0].(*redis.IntCmd)
	return cmd
}
// XTrim records an expected call of XTrim and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XTrim(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XTrim)
	return m.ctrl.RecordCallWithMethodType(m, "XTrim", method, arg0, arg1, arg2)
}
// XTrimApprox mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) XTrimApprox(arg0 context.Context, arg1 string, arg2 int64) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XTrimApprox", arg0, arg1, arg2)[0].(*redis.IntCmd)
	return cmd
}
// XTrimApprox records an expected call of XTrimApprox and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XTrimApprox(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XTrimApprox)
	return m.ctrl.RecordCallWithMethodType(m, "XTrimApprox", method, arg0, arg1, arg2)
}
// XTrimMaxLen mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) XTrimMaxLen(arg0 context.Context, arg1 string, arg2 int64) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XTrimMaxLen", arg0, arg1, arg2)[0].(*redis.IntCmd)
	return cmd
}
// XTrimMaxLen records an expected call of XTrimMaxLen and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XTrimMaxLen(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XTrimMaxLen)
	return m.ctrl.RecordCallWithMethodType(m, "XTrimMaxLen", method, arg0, arg1, arg2)
}
// XTrimMaxLenApprox mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) XTrimMaxLenApprox(arg0 context.Context, arg1 string, arg2, arg3 int64) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XTrimMaxLenApprox", arg0, arg1, arg2, arg3)[0].(*redis.IntCmd)
	return cmd
}
// XTrimMaxLenApprox records an expected call of XTrimMaxLenApprox and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XTrimMaxLenApprox(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XTrimMaxLenApprox)
	return m.ctrl.RecordCallWithMethodType(m, "XTrimMaxLenApprox", method, arg0, arg1, arg2, arg3)
}
// XTrimMinID mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) XTrimMinID(arg0 context.Context, arg1, arg2 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XTrimMinID", arg0, arg1, arg2)[0].(*redis.IntCmd)
	return cmd
}
// XTrimMinID records an expected call of XTrimMinID and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XTrimMinID(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XTrimMinID)
	return m.ctrl.RecordCallWithMethodType(m, "XTrimMinID", method, arg0, arg1, arg2)
}
// XTrimMinIDApprox mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) XTrimMinIDApprox(arg0 context.Context, arg1, arg2 string, arg3 int64) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "XTrimMinIDApprox", arg0, arg1, arg2, arg3)[0].(*redis.IntCmd)
	return cmd
}
// XTrimMinIDApprox records an expected call of XTrimMinIDApprox and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) XTrimMinIDApprox(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).XTrimMinIDApprox)
	return m.ctrl.RecordCallWithMethodType(m, "XTrimMinIDApprox", method, arg0, arg1, arg2, arg3)
}
// ZAdd mocks base method by delegating to the gomock controller. The variadic
// members are flattened after the fixed arguments. It returns nil unless the
// first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZAdd(arg0 context.Context, arg1 string, arg2 ...*redis.Z) *redis.IntCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for i := range arg2 {
		args = append(args, arg2[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZAdd", args...)[0].(*redis.IntCmd)
	return cmd
}
// ZAdd records an expected call of ZAdd. The variadic expectations follow the
// fixed ones. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZAdd(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 2+len(arg2))
	expected = append(append(expected, arg0, arg1), arg2...)
	return m.ctrl.RecordCallWithMethodType(m, "ZAdd", reflect.TypeOf((*MockUniversalClient)(nil).ZAdd), expected...)
}
// ZAddArgs mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZAddArgs(arg0 context.Context, arg1 string, arg2 redis.ZAddArgs) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZAddArgs", arg0, arg1, arg2)[0].(*redis.IntCmd)
	return cmd
}
// ZAddArgs records an expected call of ZAddArgs and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZAddArgs(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZAddArgs)
	return m.ctrl.RecordCallWithMethodType(m, "ZAddArgs", method, arg0, arg1, arg2)
}
// ZAddArgsIncr mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.FloatCmd.
func (m *MockUniversalClient) ZAddArgsIncr(arg0 context.Context, arg1 string, arg2 redis.ZAddArgs) *redis.FloatCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZAddArgsIncr", arg0, arg1, arg2)[0].(*redis.FloatCmd)
	return cmd
}
// ZAddArgsIncr records an expected call of ZAddArgsIncr and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZAddArgsIncr(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZAddArgsIncr)
	return m.ctrl.RecordCallWithMethodType(m, "ZAddArgsIncr", method, arg0, arg1, arg2)
}
// ZAddCh mocks base method by delegating to the gomock controller. The
// variadic members are flattened after the fixed arguments. It returns nil
// unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZAddCh(arg0 context.Context, arg1 string, arg2 ...*redis.Z) *redis.IntCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for i := range arg2 {
		args = append(args, arg2[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZAddCh", args...)[0].(*redis.IntCmd)
	return cmd
}
// ZAddCh records an expected call of ZAddCh. The variadic expectations follow
// the fixed ones. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZAddCh(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 2+len(arg2))
	expected = append(append(expected, arg0, arg1), arg2...)
	return m.ctrl.RecordCallWithMethodType(m, "ZAddCh", reflect.TypeOf((*MockUniversalClient)(nil).ZAddCh), expected...)
}
// ZAddNX mocks base method by delegating to the gomock controller. The
// variadic members are flattened after the fixed arguments. It returns nil
// unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZAddNX(arg0 context.Context, arg1 string, arg2 ...*redis.Z) *redis.IntCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for i := range arg2 {
		args = append(args, arg2[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZAddNX", args...)[0].(*redis.IntCmd)
	return cmd
}
// ZAddNX records an expected call of ZAddNX. The variadic expectations follow
// the fixed ones. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZAddNX(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 2+len(arg2))
	expected = append(append(expected, arg0, arg1), arg2...)
	return m.ctrl.RecordCallWithMethodType(m, "ZAddNX", reflect.TypeOf((*MockUniversalClient)(nil).ZAddNX), expected...)
}
// ZAddNXCh mocks base method by delegating to the gomock controller. The
// variadic members are flattened after the fixed arguments. It returns nil
// unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZAddNXCh(arg0 context.Context, arg1 string, arg2 ...*redis.Z) *redis.IntCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for i := range arg2 {
		args = append(args, arg2[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZAddNXCh", args...)[0].(*redis.IntCmd)
	return cmd
}
// ZAddNXCh records an expected call of ZAddNXCh. The variadic expectations
// follow the fixed ones. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZAddNXCh(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 2+len(arg2))
	expected = append(append(expected, arg0, arg1), arg2...)
	return m.ctrl.RecordCallWithMethodType(m, "ZAddNXCh", reflect.TypeOf((*MockUniversalClient)(nil).ZAddNXCh), expected...)
}
// ZAddXX mocks base method by delegating to the gomock controller. The
// variadic members are flattened after the fixed arguments. It returns nil
// unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZAddXX(arg0 context.Context, arg1 string, arg2 ...*redis.Z) *redis.IntCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for i := range arg2 {
		args = append(args, arg2[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZAddXX", args...)[0].(*redis.IntCmd)
	return cmd
}
// ZAddXX records an expected call of ZAddXX. The variadic expectations follow
// the fixed ones. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZAddXX(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 2+len(arg2))
	expected = append(append(expected, arg0, arg1), arg2...)
	return m.ctrl.RecordCallWithMethodType(m, "ZAddXX", reflect.TypeOf((*MockUniversalClient)(nil).ZAddXX), expected...)
}
// ZAddXXCh mocks base method by delegating to the gomock controller. The
// variadic members are flattened after the fixed arguments. It returns nil
// unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZAddXXCh(arg0 context.Context, arg1 string, arg2 ...*redis.Z) *redis.IntCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for i := range arg2 {
		args = append(args, arg2[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZAddXXCh", args...)[0].(*redis.IntCmd)
	return cmd
}
// ZAddXXCh records an expected call of ZAddXXCh. The variadic expectations
// follow the fixed ones. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZAddXXCh(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 2+len(arg2))
	expected = append(append(expected, arg0, arg1), arg2...)
	return m.ctrl.RecordCallWithMethodType(m, "ZAddXXCh", reflect.TypeOf((*MockUniversalClient)(nil).ZAddXXCh), expected...)
}
// ZCard mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZCard(arg0 context.Context, arg1 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZCard", arg0, arg1)[0].(*redis.IntCmd)
	return cmd
}
// ZCard records an expected call of ZCard and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZCard(arg0, arg1 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZCard)
	return m.ctrl.RecordCallWithMethodType(m, "ZCard", method, arg0, arg1)
}
// ZCount mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZCount(arg0 context.Context, arg1, arg2, arg3 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZCount", arg0, arg1, arg2, arg3)[0].(*redis.IntCmd)
	return cmd
}
// ZCount records an expected call of ZCount and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZCount(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZCount)
	return m.ctrl.RecordCallWithMethodType(m, "ZCount", method, arg0, arg1, arg2, arg3)
}
// ZDiff mocks base method by delegating to the gomock controller. The variadic
// keys are flattened after the context. It returns nil unless the first
// stubbed result is a *redis.StringSliceCmd.
func (m *MockUniversalClient) ZDiff(arg0 context.Context, arg1 ...string) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(args, arg0)
	for i := range arg1 {
		args = append(args, arg1[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZDiff", args...)[0].(*redis.StringSliceCmd)
	return cmd
}
// ZDiff records an expected call of ZDiff. The variadic expectations follow
// the fixed one. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZDiff(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 1+len(arg1))
	expected = append(append(expected, arg0), arg1...)
	return m.ctrl.RecordCallWithMethodType(m, "ZDiff", reflect.TypeOf((*MockUniversalClient)(nil).ZDiff), expected...)
}
// ZDiffStore mocks base method by delegating to the gomock controller. The
// variadic keys are flattened after the fixed arguments. It returns nil unless
// the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZDiffStore(arg0 context.Context, arg1 string, arg2 ...string) *redis.IntCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for i := range arg2 {
		args = append(args, arg2[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZDiffStore", args...)[0].(*redis.IntCmd)
	return cmd
}
// ZDiffStore records an expected call of ZDiffStore. The variadic expectations
// follow the fixed ones. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZDiffStore(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 2+len(arg2))
	expected = append(append(expected, arg0, arg1), arg2...)
	return m.ctrl.RecordCallWithMethodType(m, "ZDiffStore", reflect.TypeOf((*MockUniversalClient)(nil).ZDiffStore), expected...)
}
// ZDiffWithScores mocks base method by delegating to the gomock controller.
// The variadic keys are flattened after the context. It returns nil unless the
// first stubbed result is a *redis.ZSliceCmd.
func (m *MockUniversalClient) ZDiffWithScores(arg0 context.Context, arg1 ...string) *redis.ZSliceCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 1+len(arg1))
	args = append(args, arg0)
	for i := range arg1 {
		args = append(args, arg1[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZDiffWithScores", args...)[0].(*redis.ZSliceCmd)
	return cmd
}
// ZDiffWithScores records an expected call of ZDiffWithScores. The variadic
// expectations follow the fixed one. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZDiffWithScores(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 1+len(arg1))
	expected = append(append(expected, arg0), arg1...)
	return m.ctrl.RecordCallWithMethodType(m, "ZDiffWithScores", reflect.TypeOf((*MockUniversalClient)(nil).ZDiffWithScores), expected...)
}
// ZIncr mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.FloatCmd.
func (m *MockUniversalClient) ZIncr(arg0 context.Context, arg1 string, arg2 *redis.Z) *redis.FloatCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZIncr", arg0, arg1, arg2)[0].(*redis.FloatCmd)
	return cmd
}
// ZIncr records an expected call of ZIncr and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZIncr(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZIncr)
	return m.ctrl.RecordCallWithMethodType(m, "ZIncr", method, arg0, arg1, arg2)
}
// ZIncrBy mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.FloatCmd.
func (m *MockUniversalClient) ZIncrBy(arg0 context.Context, arg1 string, arg2 float64, arg3 string) *redis.FloatCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZIncrBy", arg0, arg1, arg2, arg3)[0].(*redis.FloatCmd)
	return cmd
}
// ZIncrBy records an expected call of ZIncrBy and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZIncrBy(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZIncrBy)
	return m.ctrl.RecordCallWithMethodType(m, "ZIncrBy", method, arg0, arg1, arg2, arg3)
}
// ZIncrNX mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.FloatCmd.
func (m *MockUniversalClient) ZIncrNX(arg0 context.Context, arg1 string, arg2 *redis.Z) *redis.FloatCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZIncrNX", arg0, arg1, arg2)[0].(*redis.FloatCmd)
	return cmd
}
// ZIncrNX records an expected call of ZIncrNX and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZIncrNX(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZIncrNX)
	return m.ctrl.RecordCallWithMethodType(m, "ZIncrNX", method, arg0, arg1, arg2)
}
// ZIncrXX mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.FloatCmd.
func (m *MockUniversalClient) ZIncrXX(arg0 context.Context, arg1 string, arg2 *redis.Z) *redis.FloatCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZIncrXX", arg0, arg1, arg2)[0].(*redis.FloatCmd)
	return cmd
}
// ZIncrXX records an expected call of ZIncrXX and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZIncrXX(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZIncrXX)
	return m.ctrl.RecordCallWithMethodType(m, "ZIncrXX", method, arg0, arg1, arg2)
}
// ZInter mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.StringSliceCmd.
func (m *MockUniversalClient) ZInter(arg0 context.Context, arg1 *redis.ZStore) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZInter", arg0, arg1)[0].(*redis.StringSliceCmd)
	return cmd
}
// ZInter records an expected call of ZInter and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZInter(arg0, arg1 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZInter)
	return m.ctrl.RecordCallWithMethodType(m, "ZInter", method, arg0, arg1)
}
// ZInterStore mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZInterStore(arg0 context.Context, arg1 string, arg2 *redis.ZStore) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZInterStore", arg0, arg1, arg2)[0].(*redis.IntCmd)
	return cmd
}
// ZInterStore records an expected call of ZInterStore and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZInterStore(arg0, arg1, arg2 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZInterStore)
	return m.ctrl.RecordCallWithMethodType(m, "ZInterStore", method, arg0, arg1, arg2)
}
// ZInterWithScores mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.ZSliceCmd.
func (m *MockUniversalClient) ZInterWithScores(arg0 context.Context, arg1 *redis.ZStore) *redis.ZSliceCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZInterWithScores", arg0, arg1)[0].(*redis.ZSliceCmd)
	return cmd
}
// ZInterWithScores records an expected call of ZInterWithScores and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZInterWithScores(arg0, arg1 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZInterWithScores)
	return m.ctrl.RecordCallWithMethodType(m, "ZInterWithScores", method, arg0, arg1)
}
// ZLexCount mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.IntCmd.
func (m *MockUniversalClient) ZLexCount(arg0 context.Context, arg1, arg2, arg3 string) *redis.IntCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZLexCount", arg0, arg1, arg2, arg3)[0].(*redis.IntCmd)
	return cmd
}
// ZLexCount records an expected call of ZLexCount and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZLexCount(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZLexCount)
	return m.ctrl.RecordCallWithMethodType(m, "ZLexCount", method, arg0, arg1, arg2, arg3)
}
// ZMScore mocks base method by delegating to the gomock controller. The
// variadic members are flattened after the fixed arguments. It returns nil
// unless the first stubbed result is a *redis.FloatSliceCmd.
func (m *MockUniversalClient) ZMScore(arg0 context.Context, arg1 string, arg2 ...string) *redis.FloatSliceCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for i := range arg2 {
		args = append(args, arg2[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZMScore", args...)[0].(*redis.FloatSliceCmd)
	return cmd
}
// ZMScore records an expected call of ZMScore. The variadic expectations
// follow the fixed ones. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZMScore(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 2+len(arg2))
	expected = append(append(expected, arg0, arg1), arg2...)
	return m.ctrl.RecordCallWithMethodType(m, "ZMScore", reflect.TypeOf((*MockUniversalClient)(nil).ZMScore), expected...)
}
// ZPopMax mocks base method by delegating to the gomock controller. The
// optional counts are flattened after the fixed arguments. It returns nil
// unless the first stubbed result is a *redis.ZSliceCmd.
func (m *MockUniversalClient) ZPopMax(arg0 context.Context, arg1 string, arg2 ...int64) *redis.ZSliceCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for i := range arg2 {
		args = append(args, arg2[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZPopMax", args...)[0].(*redis.ZSliceCmd)
	return cmd
}
// ZPopMax records an expected call of ZPopMax. The variadic expectations
// follow the fixed ones. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZPopMax(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 2+len(arg2))
	expected = append(append(expected, arg0, arg1), arg2...)
	return m.ctrl.RecordCallWithMethodType(m, "ZPopMax", reflect.TypeOf((*MockUniversalClient)(nil).ZPopMax), expected...)
}
// ZPopMin mocks base method by delegating to the gomock controller. The
// optional counts are flattened after the fixed arguments. It returns nil
// unless the first stubbed result is a *redis.ZSliceCmd.
func (m *MockUniversalClient) ZPopMin(arg0 context.Context, arg1 string, arg2 ...int64) *redis.ZSliceCmd {
	m.ctrl.T.Helper()
	args := make([]interface{}, 0, 2+len(arg2))
	args = append(args, arg0, arg1)
	for i := range arg2 {
		args = append(args, arg2[i])
	}
	cmd, _ := m.ctrl.Call(m, "ZPopMin", args...)[0].(*redis.ZSliceCmd)
	return cmd
}
// ZPopMin records an expected call of ZPopMin. The variadic expectations
// follow the fixed ones. It returns the resulting *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZPopMin(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	expected := make([]interface{}, 0, 2+len(arg2))
	expected = append(append(expected, arg0, arg1), arg2...)
	return m.ctrl.RecordCallWithMethodType(m, "ZPopMin", reflect.TypeOf((*MockUniversalClient)(nil).ZPopMin), expected...)
}
// ZRandMember mocks base method by delegating to the gomock controller.
// It returns nil unless the first stubbed result is a *redis.StringSliceCmd.
func (m *MockUniversalClient) ZRandMember(arg0 context.Context, arg1 string, arg2 int, arg3 bool) *redis.StringSliceCmd {
	m.ctrl.T.Helper()
	cmd, _ := m.ctrl.Call(m, "ZRandMember", arg0, arg1, arg2, arg3)[0].(*redis.StringSliceCmd)
	return cmd
}
// ZRandMember records an expected call of ZRandMember and returns its *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRandMember(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
	m := mr.mock
	m.ctrl.T.Helper()
	method := reflect.TypeOf((*MockUniversalClient)(nil).ZRandMember)
	return m.ctrl.RecordCallWithMethodType(m, "ZRandMember", method, arg0, arg1, arg2, arg3)
}
// ZRange mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.StringSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRange(arg0 context.Context, arg1 string, arg2, arg3 int64) *redis.StringSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRange", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*redis.StringSliceCmd)
return ret0
}
// ZRange indicates an expected call of ZRange.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRange(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRange", reflect.TypeOf((*MockUniversalClient)(nil).ZRange), arg0, arg1, arg2, arg3)
}
// ZRangeArgs mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.StringSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRangeArgs(arg0 context.Context, arg1 redis.ZRangeArgs) *redis.StringSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRangeArgs", arg0, arg1)
ret0, _ := ret[0].(*redis.StringSliceCmd)
return ret0
}
// ZRangeArgs indicates an expected call of ZRangeArgs.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRangeArgs(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRangeArgs", reflect.TypeOf((*MockUniversalClient)(nil).ZRangeArgs), arg0, arg1)
}
// ZRangeArgsWithScores mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.ZSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRangeArgsWithScores(arg0 context.Context, arg1 redis.ZRangeArgs) *redis.ZSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRangeArgsWithScores", arg0, arg1)
ret0, _ := ret[0].(*redis.ZSliceCmd)
return ret0
}
// ZRangeArgsWithScores indicates an expected call of ZRangeArgsWithScores.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRangeArgsWithScores(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRangeArgsWithScores", reflect.TypeOf((*MockUniversalClient)(nil).ZRangeArgsWithScores), arg0, arg1)
}
// ZRangeByLex mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.StringSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRangeByLex(arg0 context.Context, arg1 string, arg2 *redis.ZRangeBy) *redis.StringSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRangeByLex", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.StringSliceCmd)
return ret0
}
// ZRangeByLex indicates an expected call of ZRangeByLex.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRangeByLex(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRangeByLex", reflect.TypeOf((*MockUniversalClient)(nil).ZRangeByLex), arg0, arg1, arg2)
}
// ZRangeByScore mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.StringSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRangeByScore(arg0 context.Context, arg1 string, arg2 *redis.ZRangeBy) *redis.StringSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRangeByScore", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.StringSliceCmd)
return ret0
}
// ZRangeByScore indicates an expected call of ZRangeByScore.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRangeByScore(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRangeByScore", reflect.TypeOf((*MockUniversalClient)(nil).ZRangeByScore), arg0, arg1, arg2)
}
// ZRangeByScoreWithScores mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.ZSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRangeByScoreWithScores(arg0 context.Context, arg1 string, arg2 *redis.ZRangeBy) *redis.ZSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRangeByScoreWithScores", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.ZSliceCmd)
return ret0
}
// ZRangeByScoreWithScores indicates an expected call of ZRangeByScoreWithScores.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRangeByScoreWithScores(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRangeByScoreWithScores", reflect.TypeOf((*MockUniversalClient)(nil).ZRangeByScoreWithScores), arg0, arg1, arg2)
}
// ZRangeStore mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.IntCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRangeStore(arg0 context.Context, arg1 string, arg2 redis.ZRangeArgs) *redis.IntCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRangeStore", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.IntCmd)
return ret0
}
// ZRangeStore indicates an expected call of ZRangeStore.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRangeStore(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRangeStore", reflect.TypeOf((*MockUniversalClient)(nil).ZRangeStore), arg0, arg1, arg2)
}
// ZRangeWithScores mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.ZSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRangeWithScores(arg0 context.Context, arg1 string, arg2, arg3 int64) *redis.ZSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRangeWithScores", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*redis.ZSliceCmd)
return ret0
}
// ZRangeWithScores indicates an expected call of ZRangeWithScores.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRangeWithScores(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRangeWithScores", reflect.TypeOf((*MockUniversalClient)(nil).ZRangeWithScores), arg0, arg1, arg2, arg3)
}
// ZRank mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.IntCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRank(arg0 context.Context, arg1, arg2 string) *redis.IntCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRank", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.IntCmd)
return ret0
}
// ZRank indicates an expected call of ZRank.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRank(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRank", reflect.TypeOf((*MockUniversalClient)(nil).ZRank), arg0, arg1, arg2)
}
// ZRem mocks base method.
// arg2 is flattened into individual call arguments before forwarding to the gomock controller; the first
// configured return value is asserted to *redis.IntCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRem(arg0 context.Context, arg1 string, arg2 ...interface{}) *redis.IntCmd {
m.ctrl.T.Helper()
varargs := []interface{}{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "ZRem", varargs...)
ret0, _ := ret[0].(*redis.IntCmd)
return ret0
}
// ZRem indicates an expected call of ZRem.
// arg2 is appended after arg0 and arg1 to match the flattened call; configure via the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRem(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRem", reflect.TypeOf((*MockUniversalClient)(nil).ZRem), varargs...)
}
// ZRemRangeByLex mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.IntCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRemRangeByLex(arg0 context.Context, arg1, arg2, arg3 string) *redis.IntCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRemRangeByLex", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*redis.IntCmd)
return ret0
}
// ZRemRangeByLex indicates an expected call of ZRemRangeByLex.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRemRangeByLex(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRemRangeByLex", reflect.TypeOf((*MockUniversalClient)(nil).ZRemRangeByLex), arg0, arg1, arg2, arg3)
}
// ZRemRangeByRank mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.IntCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRemRangeByRank(arg0 context.Context, arg1 string, arg2, arg3 int64) *redis.IntCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRemRangeByRank", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*redis.IntCmd)
return ret0
}
// ZRemRangeByRank indicates an expected call of ZRemRangeByRank.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRemRangeByRank(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRemRangeByRank", reflect.TypeOf((*MockUniversalClient)(nil).ZRemRangeByRank), arg0, arg1, arg2, arg3)
}
// ZRemRangeByScore mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.IntCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRemRangeByScore(arg0 context.Context, arg1, arg2, arg3 string) *redis.IntCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRemRangeByScore", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*redis.IntCmd)
return ret0
}
// ZRemRangeByScore indicates an expected call of ZRemRangeByScore.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRemRangeByScore(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRemRangeByScore", reflect.TypeOf((*MockUniversalClient)(nil).ZRemRangeByScore), arg0, arg1, arg2, arg3)
}
// ZRevRange mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.StringSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRevRange(arg0 context.Context, arg1 string, arg2, arg3 int64) *redis.StringSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRevRange", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*redis.StringSliceCmd)
return ret0
}
// ZRevRange indicates an expected call of ZRevRange.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRevRange(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRevRange", reflect.TypeOf((*MockUniversalClient)(nil).ZRevRange), arg0, arg1, arg2, arg3)
}
// ZRevRangeByLex mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.StringSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRevRangeByLex(arg0 context.Context, arg1 string, arg2 *redis.ZRangeBy) *redis.StringSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRevRangeByLex", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.StringSliceCmd)
return ret0
}
// ZRevRangeByLex indicates an expected call of ZRevRangeByLex.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRevRangeByLex(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRevRangeByLex", reflect.TypeOf((*MockUniversalClient)(nil).ZRevRangeByLex), arg0, arg1, arg2)
}
// ZRevRangeByScore mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.StringSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRevRangeByScore(arg0 context.Context, arg1 string, arg2 *redis.ZRangeBy) *redis.StringSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRevRangeByScore", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.StringSliceCmd)
return ret0
}
// ZRevRangeByScore indicates an expected call of ZRevRangeByScore.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRevRangeByScore(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRevRangeByScore", reflect.TypeOf((*MockUniversalClient)(nil).ZRevRangeByScore), arg0, arg1, arg2)
}
// ZRevRangeByScoreWithScores mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.ZSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRevRangeByScoreWithScores(arg0 context.Context, arg1 string, arg2 *redis.ZRangeBy) *redis.ZSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRevRangeByScoreWithScores", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.ZSliceCmd)
return ret0
}
// ZRevRangeByScoreWithScores indicates an expected call of ZRevRangeByScoreWithScores.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRevRangeByScoreWithScores(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRevRangeByScoreWithScores", reflect.TypeOf((*MockUniversalClient)(nil).ZRevRangeByScoreWithScores), arg0, arg1, arg2)
}
// ZRevRangeWithScores mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.ZSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRevRangeWithScores(arg0 context.Context, arg1 string, arg2, arg3 int64) *redis.ZSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRevRangeWithScores", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*redis.ZSliceCmd)
return ret0
}
// ZRevRangeWithScores indicates an expected call of ZRevRangeWithScores.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRevRangeWithScores(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRevRangeWithScores", reflect.TypeOf((*MockUniversalClient)(nil).ZRevRangeWithScores), arg0, arg1, arg2, arg3)
}
// ZRevRank mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.IntCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZRevRank(arg0 context.Context, arg1, arg2 string) *redis.IntCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZRevRank", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.IntCmd)
return ret0
}
// ZRevRank indicates an expected call of ZRevRank.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZRevRank(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZRevRank", reflect.TypeOf((*MockUniversalClient)(nil).ZRevRank), arg0, arg1, arg2)
}
// ZScan mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.ScanCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZScan(arg0 context.Context, arg1 string, arg2 uint64, arg3 string, arg4 int64) *redis.ScanCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZScan", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(*redis.ScanCmd)
return ret0
}
// ZScan indicates an expected call of ZScan.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZScan(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZScan", reflect.TypeOf((*MockUniversalClient)(nil).ZScan), arg0, arg1, arg2, arg3, arg4)
}
// ZScore mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.FloatCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZScore(arg0 context.Context, arg1, arg2 string) *redis.FloatCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZScore", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.FloatCmd)
return ret0
}
// ZScore indicates an expected call of ZScore.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZScore(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZScore", reflect.TypeOf((*MockUniversalClient)(nil).ZScore), arg0, arg1, arg2)
}
// ZUnion mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.StringSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZUnion(arg0 context.Context, arg1 redis.ZStore) *redis.StringSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZUnion", arg0, arg1)
ret0, _ := ret[0].(*redis.StringSliceCmd)
return ret0
}
// ZUnion indicates an expected call of ZUnion.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZUnion(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZUnion", reflect.TypeOf((*MockUniversalClient)(nil).ZUnion), arg0, arg1)
}
// ZUnionStore mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.IntCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZUnionStore(arg0 context.Context, arg1 string, arg2 *redis.ZStore) *redis.IntCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZUnionStore", arg0, arg1, arg2)
ret0, _ := ret[0].(*redis.IntCmd)
return ret0
}
// ZUnionStore indicates an expected call of ZUnionStore.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZUnionStore(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZUnionStore", reflect.TypeOf((*MockUniversalClient)(nil).ZUnionStore), arg0, arg1, arg2)
}
// ZUnionWithScores mocks base method.
// The call is forwarded to the gomock controller; the first configured return value is asserted to
// *redis.ZSliceCmd (nil if it is not of that type).
func (m *MockUniversalClient) ZUnionWithScores(arg0 context.Context, arg1 redis.ZStore) *redis.ZSliceCmd {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ZUnionWithScores", arg0, arg1)
ret0, _ := ret[0].(*redis.ZSliceCmd)
return ret0
}
// ZUnionWithScores indicates an expected call of ZUnionWithScores.
// Registers the expectation via RecordCallWithMethodType; configure it through the returned *gomock.Call.
func (mr *MockUniversalClientMockRecorder) ZUnionWithScores(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ZUnionWithScores", reflect.TypeOf((*MockUniversalClient)(nil).ZUnionWithScores), arg0, arg1)
}
// Code generated by MockGen. DO NOT EDIT.
// Source: ../../pkg/security/session/store.go
// Package sessionmock is a generated GoMock package.
package sessionmock
import (
context "context"
session "github.com/cisco-open/go-lanai/pkg/security/session"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockStore is a mock of Store interface
// Create it with NewMockStore and register expectations through EXPECT().
type MockStore struct {
ctrl *gomock.Controller
recorder *MockStoreMockRecorder
}
// MockStoreMockRecorder is the mock recorder for MockStore
// Its methods register expected calls on the owning mock's controller.
type MockStoreMockRecorder struct {
mock *MockStore
}
// NewMockStore creates a new mock instance
// The mock and its recorder reference each other; mocked calls are dispatched through ctrl.
func NewMockStore(ctrl *gomock.Controller) *MockStore {
mock := &MockStore{ctrl: ctrl}
mock.recorder = &MockStoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStore) EXPECT() *MockStoreMockRecorder {
return m.recorder
}
// Get mocks base method
// ret[0] is asserted to *session.Session and ret[1] to error; each is nil when the configured value has another type.
func (m *MockStore) Get(id, name string) (*session.Session, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", id, name)
ret0, _ := ret[0].(*session.Session)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get
// Configure the expectation through the returned *gomock.Call.
func (mr *MockStoreMockRecorder) Get(id, name interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStore)(nil).Get), id, name)
}
// New mocks base method
// ret[0] is asserted to *session.Session and ret[1] to error; each is nil when the configured value has another type.
func (m *MockStore) New(name string) (*session.Session, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "New", name)
ret0, _ := ret[0].(*session.Session)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// New indicates an expected call of New
// Configure the expectation through the returned *gomock.Call.
func (mr *MockStoreMockRecorder) New(name interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "New", reflect.TypeOf((*MockStore)(nil).New), name)
}
// Save mocks base method
// ret[0] is asserted to error; nil is returned when the configured value is not an error.
func (m *MockStore) Save(s *session.Session) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Save", s)
ret0, _ := ret[0].(error)
return ret0
}
// Save indicates an expected call of Save
// Configure the expectation through the returned *gomock.Call.
func (mr *MockStoreMockRecorder) Save(s interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Save", reflect.TypeOf((*MockStore)(nil).Save), s)
}
// Invalidate mocks base method
// Each session is passed to the controller as a separate call argument; ret[0] is asserted to error
// (nil when the configured value is not an error).
func (m *MockStore) Invalidate(sessions ...*session.Session) error {
m.ctrl.T.Helper()
varargs := []interface{}{}
for _, a := range sessions {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Invalidate", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Invalidate indicates an expected call of Invalidate
// Pass one expected value/matcher per session; configure through the returned *gomock.Call.
func (mr *MockStoreMockRecorder) Invalidate(sessions ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Invalidate", reflect.TypeOf((*MockStore)(nil).Invalidate), sessions...)
}
// Options mocks base method
// ret[0] is asserted to *session.Options; nil is returned when the configured value has another type.
func (m *MockStore) Options() *session.Options {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Options")
ret0, _ := ret[0].(*session.Options)
return ret0
}
// Options indicates an expected call of Options
// Configure the expectation through the returned *gomock.Call.
func (mr *MockStoreMockRecorder) Options() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Options", reflect.TypeOf((*MockStore)(nil).Options))
}
// ChangeId mocks base method
// ret[0] is asserted to error; nil is returned when the configured value is not an error.
func (m *MockStore) ChangeId(s *session.Session) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ChangeId", s)
ret0, _ := ret[0].(error)
return ret0
}
// ChangeId indicates an expected call of ChangeId
// Configure the expectation through the returned *gomock.Call.
func (mr *MockStoreMockRecorder) ChangeId(s interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeId", reflect.TypeOf((*MockStore)(nil).ChangeId), s)
}
// AddToPrincipalIndex mocks base method
// ret[0] is asserted to error; nil is returned when the configured value is not an error.
// Note: the parameter named session shadows the session package inside the body only.
func (m *MockStore) AddToPrincipalIndex(principal string, session *session.Session) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddToPrincipalIndex", principal, session)
ret0, _ := ret[0].(error)
return ret0
}
// AddToPrincipalIndex indicates an expected call of AddToPrincipalIndex
// Configure the expectation through the returned *gomock.Call.
func (mr *MockStoreMockRecorder) AddToPrincipalIndex(principal, session interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddToPrincipalIndex", reflect.TypeOf((*MockStore)(nil).AddToPrincipalIndex), principal, session)
}
// RemoveFromPrincipalIndex mocks base method
// ret[0] is asserted to error; nil is returned when the configured value is not an error.
func (m *MockStore) RemoveFromPrincipalIndex(principal string, sessions *session.Session) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveFromPrincipalIndex", principal, sessions)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveFromPrincipalIndex indicates an expected call of RemoveFromPrincipalIndex
// Configure the expectation through the returned *gomock.Call.
func (mr *MockStoreMockRecorder) RemoveFromPrincipalIndex(principal, sessions interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveFromPrincipalIndex", reflect.TypeOf((*MockStore)(nil).RemoveFromPrincipalIndex), principal, sessions)
}
// FindByPrincipalName mocks base method
// ret[0] is asserted to []*session.Session and ret[1] to error; each is nil when the configured value has another type.
func (m *MockStore) FindByPrincipalName(principal, sessionName string) ([]*session.Session, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FindByPrincipalName", principal, sessionName)
ret0, _ := ret[0].([]*session.Session)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindByPrincipalName indicates an expected call of FindByPrincipalName
// Configure the expectation through the returned *gomock.Call.
func (mr *MockStoreMockRecorder) FindByPrincipalName(principal, sessionName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByPrincipalName", reflect.TypeOf((*MockStore)(nil).FindByPrincipalName), principal, sessionName)
}
// InvalidateByPrincipalName mocks base method
// ret[0] is asserted to error; nil is returned when the configured value is not an error.
func (m *MockStore) InvalidateByPrincipalName(principal, sessionName string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InvalidateByPrincipalName", principal, sessionName)
ret0, _ := ret[0].(error)
return ret0
}
// InvalidateByPrincipalName indicates an expected call of InvalidateByPrincipalName
// Configure the expectation through the returned *gomock.Call.
func (mr *MockStoreMockRecorder) InvalidateByPrincipalName(principal, sessionName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InvalidateByPrincipalName", reflect.TypeOf((*MockStore)(nil).InvalidateByPrincipalName), principal, sessionName)
}
// WithContext mocks base method
// ret[0] is asserted to session.Store; nil is returned when the configured value does not implement it.
func (m *MockStore) WithContext(ctx context.Context) session.Store {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WithContext", ctx)
ret0, _ := ret[0].(session.Store)
return ret0
}
// WithContext indicates an expected call of WithContext
// Configure the expectation through the returned *gomock.Call.
func (mr *MockStoreMockRecorder) WithContext(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithContext", reflect.TypeOf((*MockStore)(nil).WithContext), ctx)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearchtest
import (
"fmt"
"github.com/cisco-open/go-lanai/test/ittest"
"strings"
)
// BulkJsonBodyMatcher special body matcher for OpenSearch's _bulk API
// See https://opensearch.org/docs/2.11/api-reference/document-apis/bulk/
//
// The _bulk body is newline-delimited JSON. BulkJsonBodyMatcher splits both bodies into their non-blank
// lines and matches them pairwise using Delegate.
type BulkJsonBodyMatcher struct {
	Delegate ittest.RecordBodyMatcher
}

// Support reports whether contentType is supported; the decision is delegated to Delegate.
func (m BulkJsonBodyMatcher) Support(contentType string) bool {
	return m.Delegate.Support(contentType)
}

// Matches compares the outgoing body (out) with the recorded body (record) line by line.
// It returns an error if the number of non-blank lines differs, or the first per-line mismatch reported by
// Delegate, wrapped with the line index. The delegate's error is wrapped with %w so callers can still
// inspect it with errors.Is / errors.As.
func (m BulkJsonBodyMatcher) Matches(out []byte, record []byte) error {
	outSplit := m.split(out)
	recordSplit := m.split(record)
	if len(outSplit) != len(recordSplit) {
		return fmt.Errorf(`mismatched number of JSON objects: expect %d but got %d`, len(recordSplit), len(outSplit))
	}
	for i := range outSplit {
		if e := m.Delegate.Matches(outSplit[i], recordSplit[i]); e != nil {
			return fmt.Errorf("mismatched JSON object at index %d: %w", i, e)
		}
	}
	return nil
}

// split returns the non-blank lines of data, each with surrounding whitespace trimmed
// (this also removes a trailing "\r" from CRLF line endings).
func (m BulkJsonBodyMatcher) split(data []byte) [][]byte {
	lines := strings.Split(string(data), "\n")
	rs := make([][]byte, 0, len(lines))
	for _, line := range lines {
		if trimmed := strings.TrimSpace(line); trimmed != "" {
			rs = append(rs, []byte(trimmed))
		}
	}
	return rs
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearchtest
import (
"context"
"github.com/cisco-open/go-lanai/pkg/opensearch"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"github.com/cockroachdb/copyist"
opensearchgo "github.com/opensearch-project/opensearch-go"
"github.com/opensearch-project/opensearch-go/opensearchapi"
"github.com/opensearch-project/opensearch-go/opensearchutil"
"go.uber.org/fx"
"gopkg.in/dnaeon/go-vcr.v3/recorder"
"net/http"
"testing"
"time"
)
// IndexSuffix is the suffix we append to the index name when running opensearch tests, so that we don't
// corrupt the application's indices.
// It is the Suffix used by the hook returned from NewEditingIndexForTestingHook.
const IndexSuffix = "_test"
// IsRecording reports whether copyist is currently in recording mode.
// This is a thin wrapper around copyist.IsRecording: the recording commandline flag is shared with the
// copyist library, and registering the same flag twice via flag.Bool is not allowed, so we reuse copyist's.
func IsRecording() bool {
	recording := copyist.IsRecording()
	return recording
}
// Options is a functional option that mutates an Option (see SetRecordMode, ReplayMode, SkipRequestLatency,
// SetRecordDelay and FuzzyJsonPaths).
type Options func(opt *Option)
// Option holds the recording/playback settings assembled from Options.
type Option struct {
// Name of the recording; startRecording defaults it to the current test's name.
Name string
// SavePath is where recordings are stored; startRecording defaults it to "testdata".
SavePath string
// Mode selects recording vs. playback; WithOpenSearchPlayback and startRecording default it to ModeCommandline.
Mode Mode
// RealTransport: NOTE(review) presumably the underlying transport used when actually recording — confirm in NewRecorder.
RealTransport http.RoundTripper
// SkipRequestLatency disables mimicking the recorded request latency during playback.
SkipRequestLatency bool
// FuzzyJsonPaths lists JSONPath expressions of body parts ignored when matching during playback.
FuzzyJsonPaths []string
// RecordDelay is used by WithOpenSearchPlayback as the SearchDelayer delay while recording.
RecordDelay time.Duration
}
// SetRecordDelay sets Option.RecordDelay, the additional delay applied between requests.
// It also turns mimicking of the original request latency back on (SkipRequestLatency = false);
// note that the original latency is applied by default, and the delay comes on top of it.
func SetRecordDelay(delay time.Duration) Options {
	return func(opt *Option) {
		// same effect as SkipRequestLatency(false)(opt)
		opt.SkipRequestLatency = false
		opt.RecordDelay = delay
	}
}
// SetRecordMode overrides the recording/playback mode (default is ModeCommandline).
// It is equivalent to ReplayMode and delegates to it so the two options cannot drift apart.
func SetRecordMode(mode Mode) Options {
	return ReplayMode(mode)
}
// SkipRequestLatency disables mimicking the original request latency in playback mode. It has no effect
// during recording. By default, the latency observed during recording is reproduced in playback mode.
func SkipRequestLatency(skip bool) Options {
	setSkip := func(opt *Option) {
		opt.SkipRequestLatency = skip
	}
	return setSkip
}
// ReplayMode overrides the recording/playback mode. The default is ModeCommandline.
func ReplayMode(mode Mode) Options {
	setMode := func(opt *Option) {
		opt.Mode = mode
	}
	return setMode
}
// FuzzyJsonPaths ignores parts of the JSON body, selected with JSONPath notation, during playback mode.
// Useful for search queries with time-sensitive fields, e.g. FuzzyJsonPaths("$.query.*.Time").
// Paths accumulate across multiple uses of this option.
// JSONPath Syntax: https://goessner.net/articles/JsonPath/
func FuzzyJsonPaths(jsonPaths ...string) Options {
	return func(opt *Option) {
		for _, path := range jsonPaths {
			opt.FuzzyJsonPaths = append(opt.FuzzyJsonPaths, path)
		}
	}
}
// WithOpenSearchPlayback sets up an HTTP recorder, similar to crdb's copyist functionality: actual
// interactions with OpenSearch are recorded, and when the mode is ModeReplaying the recorder responds with
// its recorded responses instead.
//
// The returned test.Options:
//   - create the recorder in a test setup step (Name defaults to the test name, SavePath to "testdata")
//     and stop it on teardown;
//   - decorate opensearchgo.Config so that its Transport is that recorder;
//   - provide the index-editing hook (IndexEditHookProvider), which appends IndexSuffix to index names;
//   - while recording (ModeRecording, or ModeCommandline with the copyist recording flag set), provide a
//     SearchDelayer hook. It waits Option.RecordDelay (see SetRecordDelay) before a search that directly
//     follows an index command. OpenSearch does not make writes immediately available to reads, and the index
//     refresh options have not been effective, so delaying such reads is the current workaround.
func WithOpenSearchPlayback(options ...Options) test.Options {
	openSearchOption := Option{
		Mode: ModeCommandline,
	}
	for _, fn := range options {
		fn(&openSearchOption)
	}

	// rec is assigned by the setup step and read later by the Config decorator
	var rec *recorder.Recorder
	testOpts := []test.Options{
		test.Setup(
			startRecording(&rec, options...),
		),
		apptest.WithFxOptions(
			fx.Decorate(func(c opensearchgo.Config) opensearchgo.Config {
				c.Transport = rec
				return c
			}),
			fx.Provide(
				IndexEditHookProvider(opensearch.FxGroup),
			),
		),
		test.Teardown(stopRecording()),
	}

	recording := openSearchOption.Mode == ModeRecording ||
		(openSearchOption.Mode == ModeCommandline && IsRecording())
	if recording {
		testOpts = append(testOpts, apptest.WithFxOptions(
			fx.Provide(SearchDelayerHookProvider(opensearch.FxGroup, openSearchOption.RecordDelay)),
		))
	}
	return test.WithOptions(testOpts...)
}
// startRecording returns a test.SetupFunc that creates the recorder with NewRecorder, stores it in *recRef
// (so the Config decorator in WithOpenSearchPlayback can use it), and attaches it to the returned context
// (so stopRecording can find it).
// Defaults are applied before the caller's options, so callers can override them: ModeCommandline, the
// current test name as Name, and "testdata" as SavePath.
// If the recorder cannot be created, the error is returned with the unmodified context; no recorder is
// stored in the context or in *recRef.
func startRecording(recRef **recorder.Recorder, options ...Options) test.SetupFunc {
	return func(ctx context.Context, t *testing.T) (context.Context, error) {
		initial := func(c *Option) {
			c.Mode = ModeCommandline
			c.Name = t.Name()
			c.SavePath = "testdata"
		}
		opts := append([]Options{initial}, options...)
		rec, err := NewRecorder(opts...)
		if err != nil {
			return ctx, err
		}
		*recRef = rec
		return contextWithRecorder(ctx, rec), nil
	}
}
// stopRecording returns a test.TeardownFunc that stops the recorder attached to the test context by
// startRecording. It does nothing when there is no recorder in the context, including when the context
// holds a nil *recorder.Recorder (calling Stop on it would dereference nil).
func stopRecording() test.TeardownFunc {
	return func(ctx context.Context, t *testing.T) error {
		rec, ok := ctx.Value(ckRecorder{}).(*recorder.Recorder)
		if !ok || rec == nil {
			return nil
		}
		return rec.Stop()
	}
}
// SearchDelayer delays a search that directly follows an index command by Delay, giving OpenSearch time to
// make the newly indexed documents searchable.
// NOTE(review): lastEvent is not synchronized; this assumes hooks are not invoked concurrently — confirm.
type SearchDelayer struct {
	Delay     time.Duration
	lastEvent opensearch.CommandType
}

// Before sleeps for Delay when the upcoming command is a search and the most recently completed command
// (as recorded by After) was an index command. The context is returned unchanged.
func (s *SearchDelayer) Before(ctx context.Context, beforeContext opensearch.BeforeContext) context.Context {
	isSearch := beforeContext.CommandType() == opensearch.CmdSearch
	if !isSearch || s.lastEvent != opensearch.CmdIndex {
		return ctx
	}
	time.Sleep(s.Delay)
	return ctx
}

// After remembers the type of the command that just completed. The context is returned unchanged.
func (s *SearchDelayer) After(ctx context.Context, afterContext opensearch.AfterContext) context.Context {
	completed := afterContext.CommandType()
	s.lastEvent = completed
	return ctx
}
// SearchDelayerHook creates a SearchDelayer that waits delay before a search following an index command.
func SearchDelayerHook(delay time.Duration) *SearchDelayer {
	delayer := new(SearchDelayer)
	delayer.Delay = delay
	return delayer
}
// SearchDelayerHookProvider returns two fx.Annotated providers in the given group: one supplying the
// SearchDelayer as an opensearch.BeforeHook and one as an opensearch.AfterHook. Both return the same
// instance, because After records the last command that Before inspects.
func SearchDelayerHookProvider(group string, delay time.Duration) (fx.Annotated, fx.Annotated) {
	shared := SearchDelayerHook(delay)
	beforeProvider := fx.Annotated{
		Group:  group,
		Target: func() opensearch.BeforeHook { return shared },
	}
	afterProvider := fx.Annotated{
		Group:  group,
		Target: func() opensearch.AfterHook { return shared },
	}
	return beforeProvider, afterProvider
}
// EditIndexForTestingHook is an opensearch.BeforeHook that appends Suffix to the target index name(s) of
// outgoing requests, so that tests do not touch the application's own indices.
type EditIndexForTestingHook struct {
	Suffix string
}

// Order returns order.Highest.
// NOTE(review): presumably so this hook runs before other BeforeHooks — confirm against hook ordering.
func (e *EditIndexForTestingHook) Order() int {
	return order.Highest
}

// NewEditingIndexForTestingHook returns an EditIndexForTestingHook using IndexSuffix as Suffix.
func NewEditingIndexForTestingHook() opensearch.BeforeHook {
	return &EditIndexForTestingHook{
		Suffix: IndexSuffix,
	}
}

// Before appends one more request option to the option slice carried by before.Options; that option adds
// e.Suffix to the request's index name(s). Supported option slices are those for SearchRequest,
// IndicesCreateRequest, IndexRequest, IndicesPutAliasRequest, IndicesDeleteAliasRequest,
// IndicesDeleteRequest, SearchTemplateRequest and opensearchutil.BulkIndexerConfig; any other options value
// is left untouched. ctx is returned unchanged.
func (e *EditIndexForTestingHook) Before(ctx context.Context, before opensearch.BeforeContext) context.Context {
	switch opt := before.Options.(type) {
	case *[]func(request *opensearchapi.SearchRequest):
		*opt = append(*opt, func(request *opensearchapi.SearchRequest) {
			request.Index = e.suffixed(request.Index)
		})
	case *[]func(request *opensearchapi.IndicesCreateRequest):
		*opt = append(*opt, func(request *opensearchapi.IndicesCreateRequest) {
			request.Index = request.Index + e.Suffix
		})
	case *[]func(request *opensearchapi.IndexRequest):
		*opt = append(*opt, func(request *opensearchapi.IndexRequest) {
			request.Index = request.Index + e.Suffix
		})
	case *[]func(request *opensearchapi.IndicesPutAliasRequest):
		*opt = append(*opt, func(request *opensearchapi.IndicesPutAliasRequest) {
			request.Index = e.suffixed(request.Index)
		})
	case *[]func(request *opensearchapi.IndicesDeleteAliasRequest):
		*opt = append(*opt, func(request *opensearchapi.IndicesDeleteAliasRequest) {
			request.Index = e.suffixed(request.Index)
		})
	case *[]func(request *opensearchapi.IndicesDeleteRequest):
		*opt = append(*opt, func(request *opensearchapi.IndicesDeleteRequest) {
			request.Index = e.suffixed(request.Index)
		})
	case *[]func(cfg *opensearchutil.BulkIndexerConfig):
		*opt = append(*opt, func(cfg *opensearchutil.BulkIndexerConfig) {
			cfg.Index = cfg.Index + e.Suffix
		})
	case *[]func(request *opensearchapi.SearchTemplateRequest):
		*opt = append(*opt, func(request *opensearchapi.SearchTemplateRequest) {
			request.Index = e.suffixed(request.Index)
		})
	}
	return ctx
}

// suffixed returns a new slice holding every entry of indices with e.Suffix appended.
// An empty input yields nil, matching the previous append-to-nil behavior.
func (e *EditIndexForTestingHook) suffixed(indices []string) []string {
	if len(indices) == 0 {
		return nil
	}
	rs := make([]string, len(indices))
	for i, index := range indices {
		rs[i] = index + e.Suffix
	}
	return rs
}
// IndexEditHookProvider returns an fx.Annotated that contributes NewEditingIndexForTestingHook
// to the given fx value group.
func IndexEditHookProvider(group string) fx.Annotated {
	var annotated fx.Annotated
	annotated.Group = group
	annotated.Target = NewEditingIndexForTestingHook
	return annotated
}
/******************
Context
******************/
// ckRecorder is the unexported context key under which recorderContext exposes its recorder.
type ckRecorder struct{}

// recorderContext wraps a parent context and additionally answers lookups of ckRecorder{} with rec.
type recorderContext struct {
	context.Context
	rec *recorder.Recorder
}

// Value returns c.rec for the ckRecorder key and delegates every other key to the wrapped context.
func (c recorderContext) Value(k interface{}) interface{} {
	// ckRecorder is a zero-size struct, so a key equals ckRecorder{} exactly when its dynamic type is ckRecorder.
	if _, isRecorderKey := k.(ckRecorder); isRecorderKey {
		return c.rec
	}
	return c.Context.Value(k)
}
// contextWithRecorder returns a child of parent from which rec can be retrieved via the ckRecorder{} key.
func contextWithRecorder(parent context.Context, rec *recorder.Recorder) context.Context {
	ctx := recorderContext{Context: parent}
	ctx.rec = rec
	return ctx
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package opensearchtest
import (
"fmt"
"github.com/cisco-open/go-lanai/test/ittest"
"github.com/pkg/errors"
"gopkg.in/dnaeon/go-vcr.v3/recorder"
"runtime"
"strings"
)
// Errors returned by NewRecorder.
var (
// ErrCreatingRecorder is wrapped (via %w) when the underlying ittest HTTP recorder cannot be created.
ErrCreatingRecorder = errors.New("unable to create recorder")
// ErrNoCassetteName is returned when the options leave the cassette name empty.
ErrNoCassetteName = errors.New("requires cassette name")
)
// Mode selects how a recorder created by NewRecorder operates. The constants below are this
// package's own values; toHTTPVCROptions translates them into ittest modes.
type Mode recorder.Mode
// Recorder states
const (
// ModeRecording maps to ittest.ModeRecording.
ModeRecording Mode = iota
// ModeReplaying maps to ittest.ModeReplaying.
ModeReplaying
// ModeCommandline lets the command line (or the state set up in TestMain) determine the mode:
// it maps to ittest.ModeRecording when IsRecording() reports true, and ittest.ModeReplaying otherwise.
ModeCommandline
)
// NewRecorder creates a go-vcr recorder configured by the given Options.
// The underlying ittest HTTP recorder is created with ittest.HttpRecordOrdering(false) and the
// settings produced by toHTTPVCROptions.
// It returns ErrNoCassetteName when no cassette name is configured. It returns an error wrapping
// ErrCreatingRecorder when the ittest recorder cannot be created.
func NewRecorder(options ...Options) (*recorder.Recorder, error) {
	var opt Option
	for _, apply := range options {
		apply(&opt)
	}
	if len(opt.Name) == 0 {
		return nil, ErrNoCassetteName
	}
	httpRec, err := ittest.NewHttpRecorder(ittest.HttpRecordOrdering(false), toHTTPVCROptions(opt))
	if err != nil {
		return nil, fmt.Errorf("%w, %v", ErrCreatingRecorder, err)
	}
	return httpRec.Recorder, nil
}
// findTestFile searches the call stack for the test file that (directly or indirectly) called the
// caller of findTestFile. Adapted from copyist.go.
// It inspects up to 10 frames, starting two frames above findTestFile. It returns the outermost
// inspected frame whose file name ends in "_test.go", and panics when none does.
func findTestFile() string {
	const maxDepth = 10
	found := ""
	for skip := 2; skip < 2+maxDepth; skip++ {
		if _, file, _, _ := runtime.Caller(skip); strings.HasSuffix(file, "_test.go") {
			found = file
		}
	}
	if found == "" {
		panic(fmt.Errorf("open was not called directly or indirectly from a test file"))
	}
	return found
}
// toHTTPVCROptions converts an opensearchtest Option into an ittest.HTTPVCROptions.
// Mode resolution happens when the returned option is applied:
//   - ModeRecording → ittest.ModeRecording;
//   - ModeCommandline → ittest.ModeRecording when IsRecording() reports true;
//   - anything else → ittest.ModeReplaying.
// The cassette name and save path are copied as-is. A record-matching option is appended that adds
// a BulkJsonBodyMatcher (delegating to a JSON body matcher with opt.FuzzyJsonPaths) and marks the
// "User-Agent" header as fuzzy.
func toHTTPVCROptions(opt Option) ittest.HTTPVCROptions {
	return func(vcrOpt *ittest.HTTPVCROption) {
		shouldRecord := opt.Mode == ModeRecording || (opt.Mode == ModeCommandline && IsRecording())
		if shouldRecord {
			vcrOpt.Mode = ittest.ModeRecording
		} else {
			vcrOpt.Mode = ittest.ModeReplaying
		}
		vcrOpt.Name, vcrOpt.SavePath = opt.Name, opt.SavePath
		vcrOpt.RecordMatching = append(vcrOpt.RecordMatching, func(matcherOpt *ittest.RecordMatcherOption) {
			bulkMatcher := BulkJsonBodyMatcher{
				Delegate: ittest.NewRecordJsonBodyMatcher(opt.FuzzyJsonPaths...),
			}
			matcherOpt.BodyMatchers = append(matcherOpt.BodyMatchers, bulkMatcher)
			matcherOpt.FuzzyHeaders = append(matcherOpt.FuzzyHeaders, "User-Agent")
		})
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samltest
import (
"bytes"
"compress/flate"
"encoding/base64"
"encoding/xml"
"fmt"
"github.com/beevik/etree"
samlutils "github.com/cisco-open/go-lanai/pkg/security/saml/utils"
"github.com/cisco-open/go-lanai/test/webtest"
"github.com/crewjam/saml"
"io"
"net/http"
"net/url"
"strings"
)
// BindableSamlTypes is the type-set constraint for SAML messages that RequestWithSAMLPostBinding
// can inject into an HTTP request.
type BindableSamlTypes interface {
saml.LogoutRequest | saml.LogoutResponse | saml.AuthnRequest | saml.Response
}
// xmlSerializer is the common interface that pointers to saml.LogoutRequest, saml.LogoutResponse,
// saml.AuthnRequest and saml.Response are expected to satisfy: rendering themselves as an XML element.
type xmlSerializer interface {
Element() *etree.Element
}
// RequestWithSAMLPostBinding returns a webtest.RequestOptions that injects the given SAML request or
// response into the HTTP request using the HTTP-POST binding.
// The message is serialized to XML and base64-encoded. It is sent as an
// "application/x-www-form-urlencoded" body together with RelayState.
// SAML requests (saml.AuthnRequest, saml.LogoutRequest) use the "SAMLRequest" form field. SAML
// responses (saml.Response, saml.LogoutResponse) use "SAMLResponse", matching what ParseBinding
// expects for each type.
// Note: the request needs to be POST. Panics if the message cannot be serialized.
func RequestWithSAMLPostBinding[T BindableSamlTypes](samlObj *T, relayState string) webtest.RequestOptions {
	return func(req *http.Request) {
		var i interface{} = samlObj
		serializer, ok := i.(xmlSerializer)
		if !ok {
			panic(fmt.Sprintf("%T does not have Element() *etree.Element", i))
		}
		// SAML bindings: protocol responses travel as SAMLResponse, requests as SAMLRequest.
		param := "SAMLRequest"
		switch i.(type) {
		case *saml.Response, *saml.LogoutResponse:
			param = "SAMLResponse"
		}

		doc := etree.NewDocument()
		doc.SetRoot(serializer.Element())
		decoded, e := doc.WriteToBytes()
		if e != nil {
			panic(e)
		}
		values := url.Values{}
		values.Set(param, base64.StdEncoding.EncodeToString(decoded))
		values.Add("RelayState", relayState)
		body := values.Encode()

		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		req.Body = io.NopCloser(strings.NewReader(body))
		// Keep length and replay function consistent with the replaced body, so a stale ContentLength
		// from the original request does not truncate or reject the form.
		req.ContentLength = int64(len(body))
		req.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(strings.NewReader(body)), nil
		}
	}
}
// BindingParseResult holds the intermediate values produced by ParseBinding.
type BindingParseResult struct {
// Binding is saml.HTTPPostBinding or saml.HTTPRedirectBinding.
Binding string
// Values are the form values (POST) or Location query values (redirect).
Values url.Values
// Encoded is the raw base64 value of the SAML parameter.
Encoded string
// Decoded is the base64-decoded (and, when it inflates, decompressed) XML.
Decoded []byte
}
// ParseBinding extracts a SAML message from resp and unmarshals it into dest.
// A status code below 300 is treated as HTTP-POST binding (values come from the HTML form's
// <input> elements). Any other status is treated as HTTP-Redirect binding (values come from the
// Location header's query).
// The message is taken from samlutils.HttpParamSAMLRequest when dest is *saml.LogoutRequest or
// *saml.AuthnRequest, and from samlutils.HttpParamSAMLResponse otherwise.
// The value is base64-decoded, inflated when it is valid DEFLATE data, then XML-unmarshalled.
// ret is populated as far as parsing progressed, even when an error is returned.
func ParseBinding[T samlutils.ParsableSamlTypes](resp *http.Response, dest *T) (ret BindingParseResult, err error) {
	var param string
	switch interface{}(dest).(type) {
	case *saml.LogoutRequest, *saml.AuthnRequest:
		param = samlutils.HttpParamSAMLRequest
	default:
		param = samlutils.HttpParamSAMLResponse
	}

	if resp.StatusCode < 300 {
		ret.Binding = saml.HTTPPostBinding
		ret.Values, err = extractPostBindingValues(resp)
	} else {
		ret.Binding = saml.HTTPRedirectBinding
		ret.Values, err = extractRedirectBindingValues(resp)
	}
	if err != nil {
		return ret, err
	}

	if ret.Encoded = ret.Values.Get(param); ret.Encoded == "" {
		return ret, fmt.Errorf("unable to find %s in http response", param)
	}
	if ret.Decoded, err = base64.StdEncoding.DecodeString(ret.Encoded); err != nil {
		return ret, err
	}
	// Redirect-binding payloads are DEFLATE-compressed while POST-binding payloads are not;
	// attempt inflation and keep the raw bytes when it fails.
	if inflated, e := io.ReadAll(flate.NewReader(bytes.NewReader(ret.Decoded))); e == nil {
		ret.Decoded = inflated
	}
	return ret, xml.Unmarshal(ret.Decoded, dest)
}
// extractRedirectBindingValues parses resp's Location header and returns its query parameters.
// A missing Location header yields empty values; an unparsable one yields url.Parse's error.
func extractRedirectBindingValues(resp *http.Response) (url.Values, error) {
	location, err := url.Parse(resp.Header.Get("Location"))
	if err != nil {
		return nil, err
	}
	return location.Query(), nil
}
// extractPostBindingValues reads resp.Body as an HTML/XML document and collects the name/value of
// every <input> element except those whose type is "submit". An input without a name attribute is
// recorded under "unknown", and a missing value is recorded as "".
func extractPostBindingValues(resp *http.Response) (url.Values, error) {
	doc := etree.NewDocument()
	if _, err := doc.ReadFrom(resp.Body); err != nil {
		return nil, err
	}
	values := make(url.Values)
	for _, input := range doc.FindElements("//input") {
		if input.SelectAttrValue("type", "") == "submit" {
			continue
		}
		values.Add(input.SelectAttrValue("name", "unknown"), input.SelectAttrValue("value", ""))
	}
	return values, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samltest
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/test/webtest"
)
// DefaultIssuer is the issuer the SAML test mocks use to derive default entity IDs
// (see NewMockedIDP, NewMockedSP): protocol "http", domain "vms.com", port 8080 (IncludePort set)
// and webtest.DefaultContextPath as context path.
var DefaultIssuer = security.NewIssuer(func(opt *security.DefaultIssuerDetails) {
	details := security.DefaultIssuerDetails{
		Protocol:    "http",
		Domain:      "vms.com",
		Port:        8080,
		ContextPath: webtest.DefaultContextPath,
		IncludePort: true,
	}
	*opt = details
})
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samltest
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/crewjam/saml"
)
// MockedSamlAssertionAuthentication is a test double for an authentication obtained from a SAML
// assertion. It exposes Principal/Permissions/State/Details (presumably to satisfy
// security.Authentication — TODO confirm), plus the raw assertion via Assertion().
type MockedSamlAssertionAuthentication struct {
	Account       security.Account
	DetailsMap    map[string]interface{}
	SamlAssertion *saml.Assertion
}

// Principal returns the mocked Account.
func (a *MockedSamlAssertionAuthentication) Principal() interface{} {
	return a.Account
}

// Permissions converts the Account's permission list into a security.Permissions set.
// It panics if Account is nil.
func (a *MockedSamlAssertionAuthentication) Permissions() security.Permissions {
	granted := security.Permissions{}
	for _, permission := range a.Account.Permissions() {
		granted[permission] = struct{}{}
	}
	return granted
}

// State always reports security.StateAuthenticated.
func (a *MockedSamlAssertionAuthentication) State() security.AuthenticationState {
	return security.StateAuthenticated
}

// Details returns DetailsMap.
func (a *MockedSamlAssertionAuthentication) Details() interface{} {
	return a.DetailsMap
}

// Assertion returns the SAML assertion backing this authentication.
func (a *MockedSamlAssertionAuthentication) Assertion() *saml.Assertion {
	return a.SamlAssertion
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samltest
import (
"encoding/xml"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/crewjam/saml"
)
// MockedClientOptions customizes a MockedClientOption for NewMockedSamlClient.
type MockedClientOptions func(opt *MockedClientOption)
// MockedClientOption configures NewMockedSamlClient. When SP is set it takes precedence and
// Properties is ignored.
type MockedClientOption struct {
Properties MockedClientProperties
SP *saml.ServiceProvider
}
// MockedSamlClient is an in-memory SAML client (service provider registration) used by
// MockSamlClientStore; each field is returned verbatim by the corresponding getter method.
type MockedSamlClient struct {
EntityId string
// MetadataSource holds the SP metadata: XML text when built from a saml.ServiceProvider,
// otherwise whatever MockedClientProperties.MetadataSource contains.
MetadataSource string
SkipAssertionEncryption bool
SkipAuthRequestSignatureVerification bool
MetadataRequireSignature bool
MetadataTrustCheck bool
MetadataTrustedKeys []string
TenantRestrictions utils.StringSet
TenantRestrictionType string
}
// NewMockedSamlClient builds a MockedSamlClient from the given options.
// When opt.SP is set, the client mirrors that service provider. It takes the SP's EntityID and its
// XML-marshalled metadata, has no tenant restrictions, and uses restriction type "all";
// opt.Properties is ignored. In that case it returns nil if the metadata cannot be marshalled.
// Otherwise the client is populated from opt.Properties.
func NewMockedSamlClient(opts ...MockedClientOptions) *MockedSamlClient {
	var opt MockedClientOption
	for _, apply := range opts {
		apply(&opt)
	}
	if opt.SP == nil {
		props := opt.Properties
		return &MockedSamlClient{
			EntityId:                             props.EntityID,
			MetadataSource:                       props.MetadataSource,
			SkipAssertionEncryption:              props.SkipEncryption,
			SkipAuthRequestSignatureVerification: props.SkipSignatureVerification,
			TenantRestrictions:                   utils.NewStringSet(props.TenantRestriction...),
			TenantRestrictionType:                props.TenantRestrictionType,
		}
	}
	metadataXML, err := xml.Marshal(opt.SP.Metadata())
	if err != nil {
		// NOTE(review): the marshal error is discarded; callers must be prepared for a nil client.
		return nil
	}
	return &MockedSamlClient{
		EntityId:              opt.SP.EntityID,
		MetadataSource:        string(metadataXML),
		TenantRestrictions:    utils.NewStringSet(),
		TenantRestrictionType: "all",
	}
}
// ShouldMetadataRequireSignature returns the MetadataRequireSignature flag.
func (m MockedSamlClient) ShouldMetadataRequireSignature() bool { return m.MetadataRequireSignature }

// ShouldMetadataTrustCheck returns the MetadataTrustCheck flag.
func (m MockedSamlClient) ShouldMetadataTrustCheck() bool { return m.MetadataTrustCheck }

// GetMetadataTrustedKeys returns MetadataTrustedKeys as stored (not copied).
func (m MockedSamlClient) GetMetadataTrustedKeys() []string { return m.MetadataTrustedKeys }

// GetEntityId returns the client's entity ID.
func (m MockedSamlClient) GetEntityId() string { return m.EntityId }

// GetMetadataSource returns the client's metadata source.
func (m MockedSamlClient) GetMetadataSource() string { return m.MetadataSource }

// ShouldSkipAssertionEncryption returns the SkipAssertionEncryption flag.
func (m MockedSamlClient) ShouldSkipAssertionEncryption() bool { return m.SkipAssertionEncryption }

// ShouldSkipAuthRequestSignatureVerification returns the SkipAuthRequestSignatureVerification flag.
func (m MockedSamlClient) ShouldSkipAuthRequestSignatureVerification() bool {
	return m.SkipAuthRequestSignatureVerification
}

// GetTenantRestrictions returns TenantRestrictions as stored (not copied).
func (m MockedSamlClient) GetTenantRestrictions() utils.StringSet { return m.TenantRestrictions }

// GetTenantRestrictionType returns TenantRestrictionType.
func (m MockedSamlClient) GetTenantRestrictionType() string { return m.TenantRestrictionType }
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samltest
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"github.com/crewjam/saml"
)
// ClientStoreMockOptions customizes a ClientStoreMockOption for NewMockedClientStore.
type ClientStoreMockOptions func(opt *ClientStoreMockOption)
// ClientStoreMockOption configures NewMockedClientStore. Sources are used in order of precedence:
// Clients (used as-is), then SPs (converted to clients), then ClientsProperties.
type ClientStoreMockOption struct {
Clients []samlctx.SamlClient
SPs []*saml.ServiceProvider
ClientsProperties map[string]MockedClientProperties
}
// ClientsWithPropertiesPrefix returns a ClientStoreMockOptions that binds a map of client properties
// from the application config under the given prefix into ClientsProperties.
// The returned option panics if binding fails.
func ClientsWithPropertiesPrefix(appCfg bootstrap.ApplicationConfig, prefix string) ClientStoreMockOptions {
	return func(opt *ClientStoreMockOption) {
		err := appCfg.Bind(&opt.ClientsProperties, prefix)
		if err != nil {
			panic(err)
		}
	}
}
// ClientsWithSPs returns a ClientStoreMockOptions that sets the service providers to be converted
// into mocked clients (replacing any previously set SPs).
func ClientsWithSPs(sps ...*saml.ServiceProvider) ClientStoreMockOptions {
	return func(opt *ClientStoreMockOption) { opt.SPs = sps }
}
// MockSamlClientStore is an in-memory SAML client store backed by a fixed slice of clients.
type MockSamlClientStore struct {
details []samlctx.SamlClient
}
// NewMockedClientStore creates a MockSamlClientStore. Its clients come from, in order of precedence:
// opt.Clients (used as-is), clients derived from opt.SPs, or clients built from
// opt.ClientsProperties (in map iteration order, i.e. unordered).
func NewMockedClientStore(opts ...ClientStoreMockOptions) *MockSamlClientStore {
	opt := ClientStoreMockOption{}
	for _, fn := range opts {
		fn(&opt)
	}
	var details []samlctx.SamlClient
	// appendClient skips nil clients. NewMockedSamlClient returns nil when SP metadata cannot be
	// marshalled. Storing that typed nil in the interface slice would make later calls such as
	// GetEntityId panic instead of simply not matching.
	appendClient := func(c *MockedSamlClient) {
		if c != nil {
			details = append(details, c)
		}
	}
	switch {
	case len(opt.Clients) > 0:
		details = opt.Clients
	case len(opt.SPs) > 0:
		for _, sp := range opt.SPs {
			appendClient(NewMockedSamlClient(func(o *MockedClientOption) { o.SP = sp }))
		}
	default:
		for _, props := range opt.ClientsProperties {
			appendClient(NewMockedSamlClient(func(o *MockedClientOption) { o.Properties = props }))
		}
	}
	return &MockSamlClientStore{details: details}
}
// GetAllSamlClient returns a fresh copy of every client in the store (nil when empty); never errors.
func (t *MockSamlClientStore) GetAllSamlClient(_ context.Context) ([]samlctx.SamlClient, error) {
	return append([]samlctx.SamlClient(nil), t.details...), nil
}
// GetSamlClientByEntityId returns the first client whose entity ID equals id, or a "not found" error.
func (t *MockSamlClientStore) GetSamlClientByEntityId(_ context.Context, id string) (samlctx.SamlClient, error) {
	for i := range t.details {
		if client := t.details[i]; client.GetEntityId() == id {
			return client, nil
		}
	}
	return nil, errors.New("not found")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samltest
import (
"crypto/rsa"
"crypto/x509"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
"github.com/crewjam/saml"
dsig "github.com/russellhaering/goxmldsig"
"net/url"
)
// IDPMockOptions customizes an IDPMockOption for NewMockedIDP / NewMockedIdpProvider.
type IDPMockOptions func(opt *IDPMockOption)
// IDPMockOption holds the properties a mocked IDP is built from.
type IDPMockOption struct {
Properties IDPProperties
}
// IDPWithPropertiesPrefix returns an IDP mock option that binds IDP properties from the application
// config under the given prefix. The returned option panics if binding fails.
func IDPWithPropertiesPrefix(appCfg bootstrap.ApplicationConfig, prefix string) IDPMockOptions {
	return func(opt *IDPMockOption) {
		err := appCfg.Bind(&opt.Properties, prefix)
		if err != nil {
			panic(err)
		}
	}
}
// MustNewMockedIDP is like NewMockedIDP but panics instead of returning an error.
func MustNewMockedIDP(opts ...IDPMockOptions) *saml.IdentityProvider {
	mocked, err := NewMockedIDP(opts...)
	if err == nil {
		return mocked
	}
	panic(err)
}
// NewMockedIDP creates a mocked saml.IdentityProvider configured by the given IDPMockOptions.
// Unless overridden, EntityID is derived from DefaultIssuer, SSOPath is "/sso" and SLOPath is "/slo".
// The certificate and private key are loaded from CertsSource and PrivateKeySource. The metadata,
// SSO and logout URLs are resolved against EntityID.
// A load or resolve failure is returned as an error only when the corresponding property is
// non-empty. Otherwise the value is left unset (nil certificate/key, zero URL) instead of panicking.
func NewMockedIDP(opts ...IDPMockOptions) (*saml.IdentityProvider, error) {
	defaultEntityID, _ := DefaultIssuer.BuildUrl()
	opt := IDPMockOption{
		Properties: IDPProperties{
			ProviderProperties: ProviderProperties{
				EntityID: defaultEntityID.String(),
			},
			SSOPath: "/sso",
			SLOPath: "/slo",
		},
	}
	for _, fn := range opts {
		fn(&opt)
	}
	props := opt.Properties

	var e error
	var certs []*x509.Certificate
	var privKey *rsa.PrivateKey
	var metaUrl, ssoUrl, sloUrl *url.URL
	if certs, e = cryptoutils.LoadCert(props.CertsSource); e != nil && len(props.CertsSource) != 0 {
		return nil, e
	}
	if privKey, e = cryptoutils.LoadPrivateKey(props.PrivateKeySource, ""); e != nil && len(props.PrivateKeySource) != 0 {
		return nil, e
	}
	if metaUrl, e = resolveAbsUrl(props.EntityID, props.EntityID); e != nil && len(props.EntityID) != 0 {
		return nil, e
	}
	if ssoUrl, e = resolveAbsUrl(props.EntityID, props.SSOPath); e != nil && len(props.SSOPath) != 0 {
		return nil, e
	}
	if sloUrl, e = resolveAbsUrl(props.EntityID, props.SLOPath); e != nil && len(props.SLOPath) != 0 {
		return nil, e
	}

	// urlOrZero tolerates a nil URL left behind by a failed resolution of an empty property.
	urlOrZero := func(u *url.URL) url.URL {
		if u == nil {
			return url.URL{}
		}
		return *u
	}
	provider := &saml.IdentityProvider{
		MetadataURL:     urlOrZero(metaUrl),
		SSOURL:          urlOrZero(ssoUrl),
		LogoutURL:       urlOrZero(sloUrl),
		SignatureMethod: dsig.RSASHA256SignatureMethod,
	}
	// With an empty CertsSource no certificate may be loaded; avoid indexing an empty slice.
	if len(certs) > 0 {
		provider.Certificate = certs[0]
	}
	// Only assign a loaded key, so an interface-typed Key field never holds a typed nil pointer.
	if privKey != nil {
		provider.Key = privKey
	}
	return provider, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samltest
import (
"context"
"errors"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/security/idp"
samlctx "github.com/cisco-open/go-lanai/pkg/security/saml"
"sort"
)
// MockedIdpManager serves mocked identity providers from idpDetails, falling back to delegates
// (consulted after idpDetails) for providers it does not hold itself.
type MockedIdpManager struct {
idpDetails []idp.IdentityProvider
delegates []idp.IdentityProviderManager
}
// IdpManagerMockOptions customizes an IdpManagerMockOption for NewMockedIdpManager.
type IdpManagerMockOptions func(opt *IdpManagerMockOption)
// IdpManagerMockOption configures NewMockedIdpManager. A non-empty IDPList is used as-is; otherwise
// one MockedIdpProvider is built per IDPProperties entry. Delegates are fallback managers.
type IdpManagerMockOption struct {
IDPList []idp.IdentityProvider
IDPProperties map[string]IDPProperties
Delegates []idp.IdentityProviderManager
}
// IDPsWithPropertiesPrefix returns an IdpManagerMockOptions that binds a map of IDP properties from
// the application config under the given prefix. The returned option panics if binding fails.
func IDPsWithPropertiesPrefix(appCfg bootstrap.ApplicationConfig, prefix string) IdpManagerMockOptions {
	return func(opt *IdpManagerMockOption) {
		err := appCfg.Bind(&opt.IDPProperties, prefix)
		if err != nil {
			panic(err)
		}
	}
}
// IDPsWithFallback returns an IdpManagerMockOptions that sets the fallback managers consulted for
// providers (e.g. non-SAML IDPs) not held by the mock itself. Replaces any previously set delegates.
func IDPsWithFallback(delegates ...idp.IdentityProviderManager) IdpManagerMockOptions {
	return func(opt *IdpManagerMockOption) { opt.Delegates = delegates }
}
// NewMockedIdpManager creates a MockedIdpManager from the given options. A non-empty opt.IDPList is
// used as-is; otherwise one MockedIdpProvider is built per opt.IDPProperties entry (map order).
// opt.Delegates become the fallback managers.
func NewMockedIdpManager(opts ...IdpManagerMockOptions) *MockedIdpManager {
	var opt IdpManagerMockOption
	for _, apply := range opts {
		apply(&opt)
	}
	var details []idp.IdentityProvider
	if len(opt.IDPList) > 0 {
		details = opt.IDPList
	} else {
		for _, props := range opt.IDPProperties {
			details = append(details, NewMockedIdpProvider(func(o *IDPMockOption) { o.Properties = props }))
		}
	}
	manager := &MockedIdpManager{idpDetails: details}
	manager.delegates = opt.Delegates
	return manager
}
// GetIdentityProvidersWithFlow returns all mocked IDPs, regardless of flow, plus whatever each
// delegate returns for flow. The result is stably sorted by Domain() and is never nil.
func (m MockedIdpManager) GetIdentityProvidersWithFlow(ctx context.Context, flow idp.AuthenticationFlow) (ret []idp.IdentityProvider) {
	ret = append(make([]idp.IdentityProvider, 0, len(m.idpDetails)+5), m.idpDetails...)
	for _, delegate := range m.delegates {
		ret = append(ret, delegate.GetIdentityProvidersWithFlow(ctx, flow)...)
	}
	sort.SliceStable(ret, func(a, b int) bool { return ret[a].Domain() < ret[b].Domain() })
	return ret
}
// GetIdentityProviderByEntityId returns the first mocked SAML IDP whose EntityId() equals entityId.
// Failing that, it returns the first successful lookup among delegates that are SAML IDP managers.
// If neither finds one, it returns a "not found" error.
func (m MockedIdpManager) GetIdentityProviderByEntityId(ctx context.Context, entityId string) (idp.IdentityProvider, error) {
	for _, candidate := range m.idpDetails {
		samlIdp, isSaml := candidate.(samlctx.SamlIdentityProvider)
		if isSaml && samlIdp.EntityId() == entityId {
			return candidate, nil
		}
	}
	for _, delegate := range m.delegates {
		samlDelegate, isSaml := delegate.(samlctx.SamlIdentityProviderManager)
		if !isSaml {
			continue
		}
		found, err := samlDelegate.GetIdentityProviderByEntityId(ctx, entityId)
		if err == nil {
			return found, nil
		}
	}
	return nil, errors.New("not found")
}
// GetIdentityProviderByDomain returns the first mocked IDP whose Domain() equals domain. Failing
// that, it returns the first successful delegate lookup, or a "not found" error.
func (m MockedIdpManager) GetIdentityProviderByDomain(ctx context.Context, domain string) (idp.IdentityProvider, error) {
	for i := range m.idpDetails {
		if m.idpDetails[i].Domain() == domain {
			return m.idpDetails[i], nil
		}
	}
	for _, delegate := range m.delegates {
		found, err := delegate.GetIdentityProviderByDomain(ctx, domain)
		if err == nil {
			return found, nil
		}
	}
	return nil, errors.New("not found")
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samltest
import "github.com/cisco-open/go-lanai/pkg/security"
// ExtSamlMetadata carries the external SAML IDP settings exposed by MockedIdpProvider.
// NewMockedIdpProvider fills EntityId, Domain, Source, Name and IdName. RequireSignature,
// TrustCheck and TrustedKeys keep their zero values unless set directly.
type ExtSamlMetadata struct {
EntityId string
Domain string
// Source is the metadata location (returned by MetadataLocation).
Source string
Name string
IdName string
RequireSignature bool
TrustCheck bool
TrustedKeys []string
}
// NewMockedIdpProvider builds a MockedIdpProvider from the given IDPMockOptions. EntityID defaults
// to DefaultIssuer's URL. Only EntityID, Domain, MetadataSource, Name and IdName are carried over;
// the SSO/SLO path defaults are set for parity with NewMockedIDP but are not used here.
func NewMockedIdpProvider(opts ...IDPMockOptions) *MockedIdpProvider {
	defaultEntityID, _ := DefaultIssuer.BuildUrl()
	opt := IDPMockOption{
		Properties: IDPProperties{
			ProviderProperties: ProviderProperties{
				EntityID: defaultEntityID.String(),
			},
			SSOPath: "/sso",
			SLOPath: "/slo",
		},
	}
	for _, apply := range opts {
		apply(&opt)
	}
	props := opt.Properties
	metadata := ExtSamlMetadata{
		EntityId: props.EntityID,
		Domain:   props.Domain,
		Source:   props.MetadataSource,
		Name:     props.Name,
		IdName:   props.IdName,
	}
	return &MockedIdpProvider{ExtSamlMetadata: metadata}
}
// MockedIdpProvider is a mocked external SAML identity provider whose accessors return the values
// of its embedded ExtSamlMetadata.
type MockedIdpProvider struct {
	ExtSamlMetadata
}

// Domain returns the IDP's domain. The method shadows the embedded field, hence the explicit path.
func (p MockedIdpProvider) Domain() string { return p.ExtSamlMetadata.Domain }

// GetAutoCreateUserDetails always returns nil: the mock has no auto-create settings.
func (p MockedIdpProvider) GetAutoCreateUserDetails() security.AutoCreateUserDetails { return nil }

// ShouldMetadataRequireSignature returns RequireSignature.
func (p MockedIdpProvider) ShouldMetadataRequireSignature() bool {
	return p.ExtSamlMetadata.RequireSignature
}

// ShouldMetadataTrustCheck returns TrustCheck.
func (p MockedIdpProvider) ShouldMetadataTrustCheck() bool { return p.ExtSamlMetadata.TrustCheck }

// GetMetadataTrustedKeys returns TrustedKeys as stored (not copied).
func (p MockedIdpProvider) GetMetadataTrustedKeys() []string { return p.ExtSamlMetadata.TrustedKeys }

// EntityId returns the IDP's entity ID. The method shadows the embedded field, hence the explicit path.
func (p MockedIdpProvider) EntityId() string { return p.ExtSamlMetadata.EntityId }

// MetadataLocation returns Source.
func (p MockedIdpProvider) MetadataLocation() string { return p.ExtSamlMetadata.Source }

// ExternalIdName returns IdName.
func (p MockedIdpProvider) ExternalIdName() string { return p.ExtSamlMetadata.IdName }

// ExternalIdpName returns Name.
func (p MockedIdpProvider) ExternalIdpName() string { return p.ExtSamlMetadata.Name }
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samltest
import (
"crypto/rsa"
"crypto/x509"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
"github.com/crewjam/saml"
dsig "github.com/russellhaering/goxmldsig"
"net/url"
)
// SPMockOptions customizes an SPMockOption for NewMockedSP.
type SPMockOptions func(opt *SPMockOption)
// SPMockOption configures NewMockedSP. When IDP is set, its metadata takes precedence over an IDP
// described by Properties.IDP.
type SPMockOption struct {
Properties SPProperties
IDP *saml.IdentityProvider
}
// SPWithPropertiesPrefix returns an SP mock option that binds SP properties from the application
// config under the given prefix. The returned option panics if binding fails.
func SPWithPropertiesPrefix(appCfg bootstrap.ApplicationConfig, prefix string) SPMockOptions {
	return func(opt *SPMockOption) {
		err := appCfg.Bind(&opt.Properties, prefix)
		if err != nil {
			panic(err)
		}
	}
}
// SPWithIDP returns an SP mock option whose IDP metadata is taken from the given identity provider.
func SPWithIDP(idp *saml.IdentityProvider) SPMockOptions {
	return func(opt *SPMockOption) { opt.IDP = idp }
}
// MustNewMockedSP is like NewMockedSP but panics instead of returning an error.
func MustNewMockedSP(opts ...SPMockOptions) *saml.ServiceProvider {
	sp, err := NewMockedSP(opts...)
	if err == nil {
		return sp
	}
	panic(err)
}
// NewMockedSP creates a mocked saml.ServiceProvider configured by the given SPMockOptions.
// Unless overridden, EntityID is derived from DefaultIssuer, ACSPath is "/acs" and SLOPath is "/slo".
// It returns an error when the certificates cannot be loaded or the ACS URL cannot be resolved.
// A private-key or SLO failure is an error only when the corresponding property is non-empty.
// IDP metadata comes from opt.IDP, or else (best effort) from a mocked IDP built from
// opt.Properties.IDP; if that IDP cannot be built, IDPMetadata is left unset.
func NewMockedSP(opts ...SPMockOptions) (*saml.ServiceProvider, error) {
	defaultEntityID, _ := DefaultIssuer.BuildUrl()
	opt := SPMockOption{
		Properties: SPProperties{
			ProviderProperties: ProviderProperties{
				EntityID: defaultEntityID.String(),
			},
			ACSPath: "/acs",
			SLOPath: "/slo",
		},
	}
	for _, fn := range opts {
		fn(&opt)
	}

	var e error
	var spCerts []*x509.Certificate
	var privKey *rsa.PrivateKey
	var acsUrl, sloUrl *url.URL
	if spCerts, e = cryptoutils.LoadCert(opt.Properties.CertsSource); e != nil {
		return nil, e
	}
	if privKey, e = cryptoutils.LoadPrivateKey(opt.Properties.PrivateKeySource, ""); e != nil && len(opt.Properties.PrivateKeySource) != 0 {
		return nil, e
	}
	if acsUrl, e = resolveAbsUrl(opt.Properties.EntityID, opt.Properties.ACSPath); e != nil {
		return nil, e
	}
	if sloUrl, e = resolveAbsUrl(opt.Properties.EntityID, opt.Properties.SLOPath); e != nil && len(opt.Properties.SLOPath) != 0 {
		return nil, e
	}

	sp := saml.ServiceProvider{
		EntityID:          opt.Properties.EntityID,
		AcsURL:            *acsUrl,
		SignatureMethod:   dsig.RSASHA256SignatureMethod,
		AllowIDPInitiated: true,
		AuthnNameIDFormat: saml.UnspecifiedNameIDFormat,
		LogoutBindings:    []string{saml.HTTPPostBinding},
	}
	// Guard against an empty certificate list instead of panicking on spCerts[0].
	if len(spCerts) > 0 {
		sp.Certificate = spCerts[0]
	}
	// Only assign a loaded key, so an interface-typed Key field never holds a typed nil pointer.
	if privKey != nil {
		sp.Key = privKey
	}
	// SLO is optional: a failed resolution of an empty SLOPath leaves sloUrl nil.
	if sloUrl != nil {
		sp.SloURL = *sloUrl
	}

	switch {
	case opt.IDP != nil:
		sp.IDPMetadata = opt.IDP.Metadata()
	case opt.Properties.IDP != nil:
		mockedIdp, err := NewMockedIDP(func(idpOpt *IDPMockOption) { idpOpt.Properties = *opt.Properties.IDP })
		if err == nil {
			sp.IDPMetadata = mockedIdp.Metadata()
		}
	}
	return &sp, nil
}
// resolveAbsUrl resolves the path, query and fragment of toResolveUrl against baseUrl, following
// RFC 3986 reference resolution (url.URL.ResolveReference).
// Any scheme/host/userinfo in toResolveUrl is ignored, so the result always keeps baseUrl's origin.
// An empty toResolveUrl yields baseUrl. Returns an error if either string fails to parse.
func resolveAbsUrl(baseUrl, toResolveUrl string) (*url.URL, error) {
	base, e := url.Parse(baseUrl)
	if e != nil {
		return nil, e
	}
	toResolve, e := url.Parse(toResolveUrl)
	if e != nil {
		return nil, e
	}
	// Path must be included alongside RawPath: RawPath is only populated when the path contains
	// escapes, so with RawPath alone a plain "/sso" would be dropped and the base URL returned.
	ref := &url.URL{
		Path:        toResolve.Path,
		RawPath:     toResolve.RawPath,
		RawQuery:    toResolve.RawQuery,
		Fragment:    toResolve.Fragment,
		RawFragment: toResolve.RawFragment,
	}
	return base.ResolveReference(ref), nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package samltest
import (
"encoding/base64"
"fmt"
"github.com/beevik/etree"
"github.com/cisco-open/go-lanai/pkg/utils/cryptoutils"
"github.com/crewjam/saml"
"github.com/google/uuid"
"net/url"
"time"
)
// MakeAuthnRequest creates a SAML AuthnRequest from sp, addressed to idpUrl, with HTTP-POST binding
// for both the request and the expected response.
// It returns the request as a URL-encoded form body: SAMLRequest holds the base64-encoded XML and
// RelayState is "my_relay_state".
// As a test helper it panics, with the underlying cause, if the request cannot be created or
// serialized. Previously such errors were discarded and surfaced as a nil-pointer panic.
func MakeAuthnRequest(sp saml.ServiceProvider, idpUrl string) string {
	authnRequest, e := sp.MakeAuthenticationRequest(idpUrl, saml.HTTPPostBinding, saml.HTTPPostBinding)
	if e != nil {
		panic(e)
	}
	doc := etree.NewDocument()
	doc.SetRoot(authnRequest.Element())
	reqBuf, e := doc.WriteToBytes()
	if e != nil {
		panic(e)
	}
	data := url.Values{}
	data.Set("SAMLRequest", base64.StdEncoding.EncodeToString(reqBuf))
	data.Add("RelayState", "my_relay_state")
	return data.Encode()
}
// AttributeOptions customizes a saml.Attribute built by MockAttribute.
type AttributeOptions func(attr *saml.Attribute)

// MockAttribute builds a saml.Attribute whose Name and FriendlyName are both name and which holds a
// single "xs:string" value. It then applies opts in order.
// NOTE(review): NameFormat is a NameID-format URI rather than an attrname-format URI; kept as-is
// because tests may compare against it.
func MockAttribute(name, value string, opts ...AttributeOptions) saml.Attribute {
	var attr saml.Attribute
	attr.FriendlyName = name
	attr.Name = name
	attr.NameFormat = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"
	attr.Values = []saml.AttributeValue{{Type: "xs:string", Value: value}}
	for _, customize := range opts {
		customize(&attr)
	}
	return attr
}
// AssertionOptions customizes an AssertionOption for MockAssertion.
type AssertionOptions func(opt *AssertionOption)
// AssertionOption holds the values MockAssertion places into the generated assertion.
type AssertionOption struct {
Issuer string // entity ID
NameID string
NameIDFormat string
// Recipient goes into the bearer SubjectConfirmationData.
Recipient string
Audience string // entity ID
// RequestID becomes SubjectConfirmationData.InResponseTo.
RequestID string
Attributes []saml.Attribute
}
// MockAssertion builds a saml.Assertion (version 2.0, not signed here) from the given AssertionOptions.
// Defaults: NameIDFormat "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified" and a random UUID
// as RequestID.
// The assertion is valid from now until now+saml.MaxIssueDelay. It carries:
//   - a bearer subject confirmation for opt.RequestID / opt.Recipient;
//   - an audience restriction for opt.Audience;
//   - a Password AuthnStatement;
//   - one AttributeStatement holding opt.Attributes.
func MockAssertion(opts ...AssertionOptions) *saml.Assertion {
	opt := AssertionOption{
		NameIDFormat: "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified",
		RequestID:    uuid.New().String(),
	}
	for _, apply := range opts {
		apply(&opt)
	}
	now := time.Now()
	expiry := now.Add(saml.MaxIssueDelay)

	subject := &saml.Subject{
		NameID: &saml.NameID{Format: opt.NameIDFormat, Value: opt.NameID},
		SubjectConfirmations: []saml.SubjectConfirmation{{
			Method: "urn:oasis:names:tc:SAML:2.0:cm:bearer",
			SubjectConfirmationData: &saml.SubjectConfirmationData{
				InResponseTo: opt.RequestID,
				NotOnOrAfter: expiry,
				Recipient:    opt.Recipient,
			},
		}},
	}
	conditions := &saml.Conditions{
		NotBefore:    now,
		NotOnOrAfter: expiry,
		AudienceRestrictions: []saml.AudienceRestriction{{
			Audience: saml.Audience{Value: opt.Audience},
		}},
	}
	authn := saml.AuthnStatement{
		AuthnInstant: now,
		AuthnContext: saml.AuthnContext{
			AuthnContextClassRef: &saml.AuthnContextClassRef{
				Value: "urn:oasis:names:tc:SAML:2.0:ac:classes:Password",
			},
		},
	}
	return &saml.Assertion{
		ID:           fmt.Sprintf("id-%x", cryptoutils.RandomBytes(20)),
		IssueInstant: saml.TimeNow(),
		Version:      "2.0",
		Issuer: saml.Issuer{
			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
			Value:  opt.Issuer,
		},
		Subject:             subject,
		Conditions:          conditions,
		AuthnStatements:     []saml.AuthnStatement{authn},
		AttributeStatements: []saml.AttributeStatement{{Attributes: opt.Attributes}},
	}
}
// LogoutResponseOptions mutates the LogoutResponseOption consumed by MockLogoutResponse.
type LogoutResponseOptions func(opt *LogoutResponseOption)

// LogoutResponseOption carries the values MockLogoutResponse places into the generated response.
type LogoutResponseOption struct {
	Issuer    string // entity ID
	Recipient string
	Audience  string // entity ID; NOTE(review): not read by MockLogoutResponse
	RequestID string
	Success   bool
}

// MockLogoutResponse builds a saml.LogoutResponse for tests.
//
// RequestID defaults to a random UUID and Success to true. The status code is saml.StatusSuccess on success and
// saml.StatusAuthnFailed otherwise. Recipient becomes the Destination, and IssueInstant is the current time.
func MockLogoutResponse(opts ...LogoutResponseOptions) *saml.LogoutResponse {
	cfg := LogoutResponseOption{
		RequestID: uuid.New().String(),
		Success:   true,
	}
	for _, apply := range opts {
		apply(&cfg)
	}

	statusValue := saml.StatusAuthnFailed
	if cfg.Success {
		statusValue = saml.StatusSuccess
	}

	issued := time.Now()
	return &saml.LogoutResponse{
		ID:           fmt.Sprintf("id-%x", uuid.New()),
		InResponseTo: cfg.RequestID,
		Version:      "2.0",
		IssueInstant: issued,
		Destination:  cfg.Recipient,
		Issuer: &saml.Issuer{
			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
			Value:  cfg.Issuer,
		},
		Status: saml.Status{
			StatusCode: saml.StatusCode{Value: statusValue},
		},
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sdtest
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/discovery"
"time"
)
// ClientMock is an in-memory discovery.Client for tests. Instancers maps each service name to its mocked
// instancer; entries are created lazily by Instancer.
type ClientMock struct {
	ctx        context.Context
	Instancers map[string]*InstancerMock
}

// NewMockClient returns a ClientMock bound to ctx with no instancers registered yet.
func NewMockClient(ctx context.Context) *ClientMock {
	client := &ClientMock{ctx: ctx}
	client.Instancers = make(map[string]*InstancerMock)
	return client
}
/* discovery.Client implementation */

// Context returns the context the client was created with.
func (c *ClientMock) Context() context.Context {
	return c.ctx
}

// Instancer returns the mocked instancer for serviceName. The first request for a name creates and caches a new
// InstancerMock. An empty serviceName is rejected with an error.
func (c *ClientMock) Instancer(serviceName string) (discovery.Instancer, error) {
	if len(serviceName) == 0 {
		return nil, fmt.Errorf("empty service name")
	}
	existing, found := c.Instancers[serviceName]
	if !found {
		existing = NewMockInstancer(c.ctx, serviceName)
		c.Instancers[serviceName] = existing
	}
	return existing, nil
}
/* Additional mock methods */
// MockService replaces all instances of svcName with count freshly generated instances, customizes each with opts,
// and returns them. It returns nil when svcName is empty, because Instancer refuses empty service names.
func (c *ClientMock) MockService(svcName string, count int, opts ...InstanceMockOptions) []*discovery.Instance {
	instancer, e := c.mockInstancer(svcName)
	if e != nil {
		return nil
	}
	return instancer.MockInstances(count, opts...)
}

// UpdateMockedService applies opts to every instance of svcName accepted by matcher and returns how many instances
// were updated. It returns 0 when svcName has never been mocked.
func (c *ClientMock) UpdateMockedService(svcName string, matcher InstanceMockMatcher, opts ...InstanceMockOptions) (count int) {
	instancer, ok := c.Instancers[svcName]
	if !ok {
		return 0
	}
	return instancer.UpdateInstances(matcher, opts...)
}

// MockError makes the instancer of svcName report the error what, first seen at when.
// It is a no-op when svcName is empty.
func (c *ClientMock) MockError(svcName string, what error, when time.Time) {
	instancer, e := c.mockInstancer(svcName)
	if e != nil {
		return
	}
	instancer.MockError(what, when)
}

// mockInstancer resolves svcName through Instancer and returns the concrete *InstancerMock. Checking the error first
// avoids type-asserting a nil interface, which would panic when Instancer fails.
func (c *ClientMock) mockInstancer(svcName string) (*InstancerMock, error) {
	instancer, e := c.Instancer(svcName)
	if e != nil {
		return nil, e
	}
	return instancer.(*InstancerMock), nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sdtest
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
// InstanceMockOptions customizes a mocked discovery.Instance in place.
type InstanceMockOptions func(inst *discovery.Instance)

// InstanceMockMatcher reports whether a mocked instance should be affected by an update.
type InstanceMockMatcher func(inst *discovery.Instance) bool

// InstancerMock is an in-memory discovery.Instancer for a single service. Its exported fields hold the mocked state:
// the instances, an optional error with its first-seen time, and whether Start has been called.
type InstancerMock struct {
	ctx           context.Context
	SName         string
	InstanceMocks []*discovery.Instance
	ErrTimeMock   time.Time
	ErrMock       error
	Started       bool
}

// NewMockInstancer returns an InstancerMock for svcName with no instances and no mocked error.
func NewMockInstancer(ctx context.Context, svcName string) *InstancerMock {
	mock := new(InstancerMock)
	mock.ctx = ctx
	mock.SName = svcName
	mock.InstanceMocks = make([]*discovery.Instance, 0, 4)
	return mock
}
/* discovery.Instancer implementation */
// ServiceName returns the name of the mocked service.
func (i *InstancerMock) ServiceName() string {
	return i.SName
}

// Service returns a snapshot of the mocked service. The instance slice is shared with the mock (not copied),
// Time is the current time, and the mocked error is reported together with its mocked first-seen time.
func (i *InstancerMock) Service() *discovery.Service {
	svc := discovery.Service{
		Name:       i.SName,
		Insts:      i.InstanceMocks,
		Time:       time.Now(),
		Err:        i.ErrMock,
		FirstErrAt: i.ErrTimeMock,
	}
	return &svc
}

// Instances returns the mocked error if one is set. Otherwise it returns the instances accepted by matcher, which
// defaults to discovery.InstanceIsHealthy() when nil. Instances for which matching fails with an error are skipped.
func (i *InstancerMock) Instances(matcher discovery.InstanceMatcher) ([]*discovery.Instance, error) {
	if err := i.ErrMock; err != nil {
		return nil, err
	}
	accept := matcher
	if accept == nil {
		accept = discovery.InstanceIsHealthy()
	}
	matched := make([]*discovery.Instance, 0, len(i.InstanceMocks))
	for idx := range i.InstanceMocks {
		candidate := i.InstanceMocks[idx]
		ok, e := accept.MatchesWithContext(i.ctx, candidate)
		if e != nil || !ok {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched, nil
}

// Start only flips the Started flag; no background work is started.
func (i *InstancerMock) Start(_ context.Context) {
	i.Started = true
}

// Stop clears the Started flag.
func (i *InstancerMock) Stop() {
	i.Started = false
}

// RegisterCallback does nothing: the mock stores no callbacks and never notifies.
func (i *InstancerMock) RegisterCallback(_ interface{}, _ discovery.Callback) {
	// intentionally empty
}

// DeregisterCallback does nothing, mirroring RegisterCallback.
func (i *InstancerMock) DeregisterCallback(_ interface{}) {
	// intentionally empty
}
/* Additional mock methods */
// MockInstances discards the current instances and generates count new ones for this service, then returns them.
// Any mocked error is cleared when it returns.
//
// Each generated instance starts with:
//   - ID "<index>-<10 random chars>"; NthInstance and InstanceAfterN parse the index prefix
//   - Address 127.0.0.1 and Port utils.RandomIntN(32767) + 32768
//   - empty Tags and Meta
//   - discovery.HealthPassing
//
// opts are then applied to it in order.
func (i *InstancerMock) MockInstances(count int, opts ...InstanceMockOptions) []*discovery.Instance {
	defer i.resetError()
	i.InstanceMocks = make([]*discovery.Instance, count)
	for idx := range i.InstanceMocks {
		generated := &discovery.Instance{
			ID:      fmt.Sprintf("%d-%s", idx, utils.RandomString(10)),
			Service: i.SName,
			Address: "127.0.0.1",
			Port:    utils.RandomIntN(32767) + 32768,
			Tags:    []string{},
			Meta:    map[string]string{},
			Health:  discovery.HealthPassing,
		}
		for _, apply := range opts {
			apply(generated)
		}
		i.InstanceMocks[idx] = generated
	}
	return i.InstanceMocks
}
// UpdateInstances applies opts, in order, to every current instance accepted by matcher and returns the number of
// instances updated. A nil matcher matches every instance instead of panicking. Any mocked error is cleared when
// it returns.
func (i *InstancerMock) UpdateInstances(matcher InstanceMockMatcher, opts ...InstanceMockOptions) (count int) {
	defer i.resetError()
	for _, inst := range i.InstanceMocks {
		if matcher != nil && !matcher(inst) {
			continue
		}
		for _, fn := range opts {
			fn(inst)
		}
		count++
	}
	return count
}
// MockError makes Instances fail with what, and makes Service report what with first-seen time when.
func (i *InstancerMock) MockError(what error, when time.Time) {
	i.ErrMock, i.ErrTimeMock = what, when
}

// resetError clears the mocked error and its timestamp.
func (i *InstancerMock) resetError() {
	i.ErrMock, i.ErrTimeMock = nil, time.Time{}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sdtest
import (
"github.com/cisco-open/go-lanai/pkg/discovery"
"strconv"
"strings"
)
// BeHealthy marks a mocked instance as discovery.HealthPassing.
func BeHealthy() InstanceMockOptions {
	return func(inst *discovery.Instance) { inst.Health = discovery.HealthPassing }
}

// BeCritical marks a mocked instance as discovery.HealthCritical.
func BeCritical() InstanceMockOptions {
	return func(inst *discovery.Instance) { inst.Health = discovery.HealthCritical }
}

// WithExtraTag appends tags to a mocked instance's existing tags.
func WithExtraTag(tags ...string) InstanceMockOptions {
	return func(inst *discovery.Instance) {
		merged := append(inst.Tags, tags...)
		inst.Tags = merged
	}
}
// WithMeta sets metadata key k to v on a mocked instance. If the instance has no Meta map yet (e.g. it did not
// come from MockInstances), one is allocated first; writing to a nil map would panic.
func WithMeta(k, v string) InstanceMockOptions {
	return func(inst *discovery.Instance) {
		if inst.Meta == nil {
			inst.Meta = map[string]string{}
		}
		inst.Meta[k] = v
	}
}
// AnyInstance matches every instance.
func AnyInstance() InstanceMockMatcher {
	return func(_ *discovery.Instance) bool { return true }
}

// AnyHealthyInstance matches instances whose health is discovery.HealthPassing.
func AnyHealthyInstance() InstanceMockMatcher {
	return func(inst *discovery.Instance) bool {
		healthy := inst.Health == discovery.HealthPassing
		return healthy
	}
}

// NthInstance matches the instance whose ID index prefix equals n. MockInstances generates that prefix; IDs
// without a numeric prefix are treated as index -1.
func NthInstance(n int) InstanceMockMatcher {
	return func(inst *discovery.Instance) bool {
		return extractIndexIfPossible(inst) == n
	}
}

// InstanceAfterN matches instances whose ID index prefix is strictly greater than n.
func InstanceAfterN(n int) InstanceMockMatcher {
	return func(inst *discovery.Instance) bool {
		return extractIndexIfPossible(inst) > n
	}
}
// extractIndexIfPossible parses the text before the first "-" of inst.ID (or the whole ID when it has no "-")
// as an integer index. It returns -1 when that text is not an integer.
func extractIndexIfPossible(inst *discovery.Instance) int {
	prefix := inst.ID
	if pos := strings.Index(prefix, "-"); pos >= 0 {
		prefix = prefix[:pos]
	}
	if idx, e := strconv.Atoi(prefix); e == nil {
		return idx
	}
	return -1
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
// Package sdtest provides test utilities to mock the service discovery client.
package sdtest
import (
"context"
"dario.cat/mergo"
"errors"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/discovery"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"github.com/ghodss/yaml"
"go.uber.org/fx"
"io"
"io/fs"
"testing"
)
// DI holds the dependencies injected into tests that use WithMockedSD.
type DI struct {
	fx.In
	AppCtx *bootstrap.ApplicationContext
	Client *ClientMock
}

// SDMockOptions customizes an SDMockOption.
type SDMockOptions func(opt *SDMockOption)

// SDMockOption selects where WithMockedSD loads mocked service definitions from: a YAML file (FS + DefPath)
// or application properties under PropertiesPrefix.
type SDMockOption struct {
	FS               fs.FS
	DefPath          string
	PropertiesPrefix string
}
// WithMockedSD is a test option that provides a ClientMock as the discovery.Client and mocks services before each
// sub-test. When both a file system and a definition path are configured (see LoadDefinition), definitions are read
// from that YAML file; otherwise they are bound from application properties under the configured prefix.
func WithMockedSD(opts ...SDMockOptions) test.Options {
	var opt SDMockOption
	for _, fn := range opts {
		fn(&opt)
	}

	var di DI
	var setup test.SetupFunc
	if opt.FS != nil && opt.DefPath != "" {
		setup = SetupServicesWithFile(&di, opt.FS, opt.DefPath)
	} else {
		setup = SetupServicesWithProperties(&di, opt.PropertiesPrefix)
	}

	return test.WithOptions(
		apptest.WithFxOptions(fx.Provide(ProvideDiscoveryClient)),
		apptest.WithDI(&di),
		test.SubTestSetup(setup),
	)
}
// LoadDefinition loads mocked service definitions from path within fsys. It takes precedence over
// DefinitionWithPrefix.
func LoadDefinition(fsys fs.FS, path string) SDMockOptions {
	return func(opt *SDMockOption) {
		opt.FS, opt.DefPath = fsys, path
	}
}

// DefinitionWithPrefix loads mocked service definitions from application properties under prefix.
func DefinitionWithPrefix(prefix string) SDMockOptions {
	return func(opt *SDMockOption) { opt.PropertiesPrefix = prefix }
}

// ProvideDiscoveryClient creates one ClientMock and provides it both as discovery.Client and as *ClientMock.
func ProvideDiscoveryClient(ctx *bootstrap.ApplicationContext) (discovery.Client, *ClientMock) {
	mock := NewMockClient(ctx)
	return mock, mock
}
// SetupServicesWithFile returns a test setup function that mocks di.Client with the service definitions in the
// YAML file at path within fsys. It fails if the client mock has not been injected.
func SetupServicesWithFile(di *DI, fsys fs.FS, path string) test.SetupFunc {
	return func(ctx context.Context, _ *testing.T) (context.Context, error) {
		if di == nil || di.Client == nil {
			return nil, errors.New("discovery client mock is not available")
		}
		return ctx, MockServicesFromFile(di.Client, fsys, path)
	}
}

// SetupServicesWithProperties returns a test setup function that mocks di.Client with the service definitions
// bound from application properties under prefix. It fails if the client mock has not been injected.
func SetupServicesWithProperties(di *DI, prefix string) test.SetupFunc {
	return func(ctx context.Context, _ *testing.T) (context.Context, error) {
		if di == nil || di.Client == nil {
			return nil, errors.New("discovery client mock is not available")
		}
		return ctx, MockServicesFromProperties(di.Client, di.AppCtx.Config(), prefix)
	}
}
// MockServicesFromFile reads a YAML map of service name to instance definitions from path within fsys and applies
// it to client via MockServices.
func MockServicesFromFile(client *ClientMock, fsys fs.FS, path string) error {
	file, e := fsys.Open(path)
	if e != nil {
		return e
	}
	defer func() { _ = file.Close() }()

	raw, e := io.ReadAll(file)
	if e != nil {
		return e
	}
	var services map[string][]*discovery.Instance
	if e = yaml.Unmarshal(raw, &services); e != nil {
		return e
	}
	return MockServices(client, services)
}

// MockServicesFromProperties binds service definitions from application properties under prefix and applies them
// to client via MockServices.
func MockServicesFromProperties(client *ClientMock, appCfg bootstrap.ApplicationConfig, prefix string) error {
	var services map[string][]*discovery.Instance
	e := appCfg.Bind(&services, prefix)
	if e == nil {
		e = MockServices(client, services)
	}
	return e
}
// MockServices mocks client with services, a map from service name to instance definitions.
//
// For each service, len(definitions) instances are generated by ClientMock.MockService. Each definition is merged
// into its generated instance with mergo.Merge (options WithAppendSlice and WithSliceDeepCopy), and definitions
// with health discovery.HealthAny are reported as passing. nil definitions (e.g. a null YAML list entry) keep the
// generated defaults instead of panicking. The first merge error stops processing of remaining services and is
// returned.
func MockServices(client *ClientMock, services map[string][]*discovery.Instance) (err error) {
	for name, defs := range services {
		var idx int
		client.MockService(name, len(defs), func(inst *discovery.Instance) {
			def := defs[idx]
			idx++
			if def == nil {
				// nothing to merge; also avoids dereferencing nil at def.Health below
				return
			}
			if e := mergo.Merge(inst, def, mergo.WithAppendSlice, mergo.WithSliceDeepCopy); e != nil && err == nil {
				err = e
			}
			if def.Health == discovery.HealthAny {
				inst.Health = discovery.HealthPassing
			}
		})
		if err != nil {
			break
		}
	}
	return err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/session"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/middleware"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"go.uber.org/fx"
"net/http"
)
/**************************
Context
**************************/
// MWMockContext carries the per-request values an MWMocker may use to produce a mocked authentication.
type MWMockContext struct {
	Request *http.Request
}

// MWMocker is called by the mocked authentication middleware to produce each request's security.Authentication.
type MWMocker interface {
	Mock(MWMockContext) security.Authentication
}

/**************************
	Test Options
**************************/

// MWMockOptions customizes an MWMockOption.
type MWMockOptions func(opt *MWMockOption)

// MWMockOption configures WithMockedMiddleware and RegisterTestConfigurer.
type MWMockOption struct {
	Route         web.RouteMatcher
	Condition     web.RequestMatcher
	MWMocker      MWMocker
	MWOrder       int
	Configurer    security.Configurer
	Session       bool
	ForceOverride bool
}

// defaultMWMockOption applies the mocked middleware to every route at security.MWOrderPreAuth + 5, extracting the
// authentication directly from the request context.
var defaultMWMockOption = MWMockOption{
	Route:    matcher.AnyRoute(),
	MWMocker: DirectExtractionMWMocker{},
	MWOrder:  security.MWOrderPreAuth + 5,
}
// WithMockedMiddleware is a test option that automatically installs a middleware that populates/saves
// security.Authentication into gin.Context.
//
// This test option works with webtest.WithMockedServer without any additional settings:
//   - By default, security.Authentication is extracted from the request's context.
//     Note: since gin-gonic v1.8.0+, this test option is no longer required for webtest.WithMockedServer, because
//     values in the request's context are automatically linked with gin.Context.
//
// When used with webtest.WithRealServer, a custom MWMocker is required. The MWMocker can be provided by:
//   - using the MWCustomMocker option
//   - providing a MWMocker via uber/fx
//   - providing a security.Configurer with NewMockedMW:
//     <code>
//     func realServerSecConfigurer(ws security.WebSecurity) {
//     	ws.Route(matcher.AnyRoute()).
//     		With(NewMockedMW().
//     			Mocker(MWMockFunc(realServerMockFunc)),
//     		)
//     }
//     </code>
//
// See the examples package for more details.
func WithMockedMiddleware(opts ...MWMockOptions) test.Options {
	opt := defaultMWMockOption
	for _, fn := range opts {
		fn(&opt)
	}

	testOpts := []test.Options{
		apptest.WithModules(security.Module),
		apptest.WithFxOptions(fx.Invoke(registerSecTest)),
	}

	if mocker := opt.MWMocker; mocker != nil {
		testOpts = append(testOpts, apptest.WithFxOptions(fx.Provide(func() MWMocker { return mocker })))
	}

	// either register the caller's configurer as-is, or the default test configurer built from opts
	var configure interface{}
	if custom := opt.Configurer; custom != nil {
		configure = func(reg security.Registrar) {
			reg.Register(custom)
		}
	} else {
		configure = RegisterTestConfigurer(opts...)
	}
	testOpts = append(testOpts, apptest.WithFxOptions(fx.Invoke(configure)))

	if opt.Session {
		testOpts = append(testOpts,
			apptest.WithModules(session.Module),
			apptest.WithFxOptions(fx.Decorate(MockedSessionStoreDecorator)),
		)
	}
	return test.WithOptions(testOpts...)
}
// MWRoute returns an option for WithMockedMiddleware that sets the route of the default test security.Configurer
// to the OR-combination of matchers. With no matchers, the route is left unchanged.
func MWRoute(matchers ...web.RouteMatcher) MWMockOptions {
	return func(opt *MWMockOption) {
		if len(matchers) == 0 {
			return
		}
		combined := matchers[0]
		for _, m := range matchers[1:] {
			combined = combined.Or(m)
		}
		opt.Route = combined
	}
}
// MWCondition returns an option for WithMockedMiddleware that sets the request condition of the default test
// security.Configurer to the OR-combination of matchers. With no matchers, the condition is left unchanged.
func MWCondition(matchers ...web.RequestMatcher) MWMockOptions {
	return func(opt *MWMockOption) {
		for i, m := range matchers {
			if i == 0 {
				opt.Condition = m
			} else {
				// chain onto the condition being built; previously this mistakenly chained onto opt.Route
				opt.Condition = opt.Condition.Or(m)
			}
		}
	}
}
// MWEnableSession returns an option for WithMockedMiddleware that enables the in-memory session: session.Module is
// installed and the session store is decorated with MockedSessionStoreDecorator.
func MWEnableSession() MWMockOptions {
	return func(opt *MWMockOption) { opt.Session = true }
}

// MWForcePreOAuth2AuthValidation returns an option for WithMockedMiddleware that moves the mocking middleware to
// security.MWOrderOAuth2AuthValidation - 5, so that it runs before OAuth2 authorize validation.
func MWForcePreOAuth2AuthValidation() MWMockOptions {
	return func(opt *MWMockOption) { opt.MWOrder = security.MWOrderOAuth2AuthValidation - 5 }
}

// MWForceOverride returns an option for WithMockedMiddleware that adds a second middleware after the last auth
// middleware (before access control), which overrides whatever other installed authenticators produced.
func MWForceOverride() MWMockOptions {
	return func(opt *MWMockOption) { opt.ForceOverride = true }
}

// MWCustomConfigurer returns an option for WithMockedMiddleware that sets the security.Configurer to register.
// When nil, a default configurer is generated from MWMockOption.Route and MWMockOption.Condition; when non-nil,
// those two fields are ignored.
func MWCustomConfigurer(configurer security.Configurer) MWMockOptions {
	return func(opt *MWMockOption) { opt.Configurer = configurer }
}

// MWCustomMocker returns an option for WithMockedMiddleware that sets the MWMocker. When nil, the MWMocker provided
// through uber/fx is used.
func MWCustomMocker(mocker MWMocker) MWMockOptions {
	return func(opt *MWMockOption) { opt.MWMocker = mocker }
}
/**************************
Mockers
**************************/
// MWMockFunc adapts a plain function to the MWMocker interface.
type MWMockFunc func(MWMockContext) security.Authentication

// Mock implements MWMocker by invoking f.
func (f MWMockFunc) Mock(mc MWMockContext) security.Authentication {
	return f(mc)
}

// DirectExtractionMWMocker is an MWMocker that returns the security.Authentication already stored in the request's
// context. It works together with webtest.WithMockedServer and WithMockedSecurity, where the authentication is
// injected into a context that is passed directly into the http.Request.
type DirectExtractionMWMocker struct{}

// Mock returns security.Get applied to the request's context.
func (m DirectExtractionMWMocker) Mock(mc MWMockContext) security.Authentication {
	reqCtx := mc.Request.Context()
	return security.Get(reqCtx)
}
/**************************
Feature
**************************/
var (
	// FeatureId identifies the mocked-authentication feature, ordered as an authenticator feature.
	FeatureId = security.FeatureId("SecTest", security.FeatureOrderAuthenticator)
)

// regDI injects the optional security.Registrar used by registerSecTest.
type regDI struct {
	fx.In
	SecRegistrar security.Registrar `optional:"true"`
}

// registerSecTest registers the mocked-authentication feature configurer when a security.Registrar is available.
// The registrar is type-asserted to security.FeatureRegistrar without a check.
func registerSecTest(di regDI) {
	if di.SecRegistrar == nil {
		return
	}
	featureReg := di.SecRegistrar.(security.FeatureRegistrar)
	featureReg.RegisterFeature(FeatureId, newFeatureConfigurer())
}
// Feature is the security.Feature that installs the mocked authentication middleware.
type Feature struct {
	MWOrder  int
	MWMocker MWMocker
	Override bool
}

// NewMockedMW is the standard DSL-style security.Feature entrypoint, used with security.WebSecurity.
// It starts from the default middleware order and mocker.
func NewMockedMW() *Feature {
	f := &Feature{}
	f.MWOrder = defaultMWMockOption.MWOrder
	f.MWMocker = defaultMWMockOption.MWMocker
	return f
}

// Order sets the middleware order.
func (f *Feature) Order(mwOrder int) *Feature {
	f.MWOrder = mwOrder
	return f
}

// Mocker sets the MWMocker.
func (f *Feature) Mocker(mocker MWMocker) *Feature {
	f.MWMocker = mocker
	return f
}

// ForceOverride toggles installation of the additional override middleware.
func (f *Feature) ForceOverride(override bool) *Feature {
	f.Override = override
	return f
}

// MWMockFunc sets the MWMocker from a plain function.
func (f *Feature) MWMockFunc(mocker MWMockFunc) *Feature {
	f.MWMocker = mocker
	return f
}

// Identifier returns FeatureId.
func (f *Feature) Identifier() security.FeatureIdentifier {
	return FeatureId
}
// Configure enables the mocked authentication middleware feature on ws and returns it for further DSL-style
// configuration. It panics if ws does not implement security.FeatureModifier.
func Configure(ws security.WebSecurity) *Feature {
	feature := NewMockedMW()
	if fc, ok := ws.(security.FeatureModifier); ok {
		return fc.Enable(feature).(*Feature)
	}
	// the message previously said "configure session" (copy-paste); this feature is the mocked auth middleware
	panic(fmt.Errorf("unable to configure mocked auth middleware: provided WebSecurity [%T] doesn't support FeatureModifier", ws))
}
// FeatureConfigurer installs the middleware described by a *Feature into a security.WebSecurity.
type FeatureConfigurer struct {
}

func newFeatureConfigurer() *FeatureConfigurer {
	return &FeatureConfigurer{}
}

// Apply adds "mocked-auth-mw" at the feature's order. When the feature's Override is set, it also adds
// "mocked-auth-override-mw" at security.MWOrderAccessControl - 5. Both middlewares share a single
// MockAuthenticationMiddleware. It always returns nil.
func (c *FeatureConfigurer) Apply(feature security.Feature, ws security.WebSecurity) error {
	f := feature.(*Feature)
	mock := &MockAuthenticationMiddleware{MWMocker: f.MWMocker}

	ws.Add(middleware.NewBuilder("mocked-auth-mw").
		Order(f.MWOrder).
		Use(mock.AuthenticationHandlerFunc()))

	if !f.Override {
		return nil
	}
	ws.Add(middleware.NewBuilder("mocked-auth-override-mw").
		Order(security.MWOrderAccessControl - 5).
		Use(mock.ForceOverrideHandlerFunc()))
	return nil
}
/**************************
Security Configurer
**************************/
// mwDI injects the optional registrar and mocker used by RegisterTestConfigurer.
type mwDI struct {
	fx.In
	Registrar security.Registrar `optional:"true"`
	Mocker    MWMocker           `optional:"true"`
}

// RegisterTestConfigurer returns an fx invoke target that registers the default test security.Configurer built from
// opts. When no MWMocker is configured, the fx-provided one (if any) is used. Like registerSecTest, it does nothing
// when no security.Registrar is available; calling Register on the nil optional dependency would panic.
func RegisterTestConfigurer(opts ...MWMockOptions) func(di mwDI) {
	opt := defaultMWMockOption
	for _, fn := range opts {
		fn(&opt)
	}
	return func(di mwDI) {
		if di.Registrar == nil {
			return
		}
		if opt.MWMocker == nil {
			opt.MWMocker = di.Mocker
		}
		configurer := security.ConfigurerFunc(newTestSecurityConfigurer(&opt))
		di.Registrar.Register(configurer)
	}
}
// newTestSecurityConfigurer returns a configurer that applies the mocked middleware feature to opt.Route, using the
// order, mocker and override flag from opt, and restricts it with opt.Condition when one is set.
func newTestSecurityConfigurer(opt *MWMockOption) func(ws security.WebSecurity) {
	return func(ws security.WebSecurity) {
		feature := NewMockedMW().
			Order(opt.MWOrder).
			Mocker(opt.MWMocker).
			ForceOverride(opt.ForceOverride)
		var routed security.WebSecurity = ws.Route(opt.Route).With(feature)
		if opt.Condition == nil {
			return
		}
		routed.Condition(opt.Condition)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
// Examples. This file exists to work around an issue that has existed since Go 1.3:
// https://github.com/golang/go/issues/8279
// Note: this issue was fixed in Go 1.17
package examples
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/integrate/security/scope"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/access"
"github.com/cisco-open/go-lanai/pkg/security/basicauth"
"github.com/cisco-open/go-lanai/pkg/security/errorhandling"
"github.com/cisco-open/go-lanai/pkg/security/redirect"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/cisco-open/go-lanai/pkg/web/matcher"
"github.com/cisco-open/go-lanai/pkg/web/rest"
"net/http"
)
/*************************
Examples Setup
*************************/
// TestTarget is a sample component whose methods demonstrate security-scoped calls.
type TestTarget struct{}

// DoSomethingWithinSecurityScope runs DoSomethingRequiringSecurity in a scope switched to the system account.
// It returns scope.Do's error; the inner call's error is deliberately discarded.
func (t *TestTarget) DoSomethingWithinSecurityScope(ctx context.Context) error {
	work := func(scopedCtx context.Context) {
		// scopedCtx carries the switched security context
		_ = t.DoSomethingRequiringSecurity(scopedCtx)
	}
	return scope.Do(ctx, work, scope.UseSystemAccount())
}

// DoSomethingRequiringSecurity returns an error unless ctx carries a fully authenticated security.Authentication.
func (t *TestTarget) DoSomethingRequiringSecurity(ctx context.Context) error {
	if security.IsFullyAuthenticated(security.Get(ctx)) {
		return nil
	}
	return fmt.Errorf("not authenticated")
}
const (
	TestSecuredURL    = "/api/v1/secured"
	TestEntryPointURL = "/login"
)

// TestController serves GET and POST on TestSecuredURL, both handled by Secured.
type TestController struct{}

func registerTestController(reg *web.Registrar) {
	reg.MustRegister(&TestController{})
}

// Mappings returns the "secured-get" and "secured-post" REST mappings.
func (c *TestController) Mappings() []web.Mapping {
	get := rest.New("secured-get").Get(TestSecuredURL).EndpointFunc(c.Secured).Build()
	post := rest.New("secured-post").Post(TestSecuredURL).EndpointFunc(c.Secured).Build()
	return []web.Mapping{get, post}
}

// Secured always responds with {"Message": "Yes"}.
func (c *TestController) Secured(_ context.Context, _ *http.Request) (interface{}, error) {
	body := map[string]interface{}{"Message": "Yes"}
	return body, nil
}
// TestSecConfigurer secures "/api/**". It uses basic auth, whose entry point redirects to TestEntryPointURL,
// requires any request to be authenticated, and applies errorhandling.New().
type TestSecConfigurer struct{}

func (c *TestSecConfigurer) Configure(ws security.WebSecurity) {
	entryPoint := redirect.NewRedirectWithRelativePath(TestEntryPointURL, false)
	ws.Route(matcher.RouteWithPattern("/api/**")).
		With(basicauth.New().EntryPoint(entryPoint)).
		With(access.New().Request(matcher.AnyRequest()).Authenticated()).
		With(errorhandling.New())
}

func registerTestSecurity(registrar security.Registrar) {
	registrar.Register(&TestSecConfigurer{})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
"strings"
)
const (
	// idPrefix is prepended to an external ID (or username) to derive the internal ID when none is configured.
	idPrefix = "id-"
)

/*************************
	Account Auth
*************************/

// MockedAccountAuthentication is a security.Authentication whose principal is a MockedAccount.
type MockedAccountAuthentication struct {
	Account    MockedAccount
	AuthState  security.AuthenticationState
	DetailsMap map[string]interface{}
}

// Principal returns the mocked account.
func (a MockedAccountAuthentication) Principal() interface{} { return a.Account }

// Permissions returns a fresh security.Permissions set built from the account's permission set.
func (a MockedAccountAuthentication) Permissions() security.Permissions {
	source := a.Account.MockedAccountDetails.Permissions
	granted := make(security.Permissions, len(source))
	for name := range source {
		granted[name] = struct{}{}
	}
	return granted
}

// State returns the configured authentication state.
func (a MockedAccountAuthentication) State() security.AuthenticationState { return a.AuthState }

// Details returns DetailsMap.
func (a MockedAccountAuthentication) Details() interface{} { return a.DetailsMap }
/*************************
Account
*************************/
// MockedAccountDetails holds the raw values backing a MockedAccount.
type MockedAccountDetails struct {
	UserId          string
	Type            security.AccountType
	Username        string
	Password        string
	TenantId        string
	DefaultTenant   string
	AssignedTenants utils.StringSet
	Permissions     utils.StringSet
}

// MockedAccount is a security.Account backed by MockedAccountDetails. It is never disabled or locked, and never
// uses MFA.
type MockedAccount struct {
	MockedAccountDetails
}

// DefaultDesignatedTenantId returns the configured default tenant.
func (m MockedAccount) DefaultDesignatedTenantId() string { return m.DefaultTenant }

// DesignatedTenantIds returns the assigned tenants.
func (m MockedAccount) DesignatedTenantIds() []string { return m.AssignedTenants.Values() }

// TenantId returns the account's tenant ID.
func (m MockedAccount) TenantId() string { return m.MockedAccountDetails.TenantId }

// ID returns the user ID.
func (m MockedAccount) ID() interface{} { return m.UserId }

// Type returns the account type.
func (m MockedAccount) Type() security.AccountType { return m.MockedAccountDetails.Type }

// Username returns the username.
func (m MockedAccount) Username() string { return m.MockedAccountDetails.Username }

// Credentials returns the plain password.
func (m MockedAccount) Credentials() interface{} { return m.MockedAccountDetails.Password }

// Permissions returns the permission names.
func (m MockedAccount) Permissions() []string { return m.MockedAccountDetails.Permissions.Values() }

// Disabled always returns false.
func (m MockedAccount) Disabled() bool { return false }

// Locked always returns false.
func (m MockedAccount) Locked() bool { return false }

// UseMFA always returns false.
func (m MockedAccount) UseMFA() bool { return false }

// CacheableCopy returns the account itself (a value copy).
func (m MockedAccount) CacheableCopy() security.Account { return m }
// newMockedAccount builds an app-type MockedAccount from props. When UserId is missing it is derived from Username;
// otherwise, when Username is missing, it is derived from UserId.
func newMockedAccount(props *MockedAccountProperties) *MockedAccount {
	details := MockedAccountDetails{
		UserId:          props.UserId,
		Type:            security.AccountTypeApp,
		Username:        props.Username,
		Password:        props.Password,
		DefaultTenant:   props.DefaultTenant,
		AssignedTenants: utils.NewStringSet(props.Tenants...),
		Permissions:     utils.NewStringSet(props.Perms...),
	}
	if details.UserId == "" {
		details.UserId = extIdToId(details.Username)
	} else if details.Username == "" {
		details.Username = idToExtId(details.UserId)
	}
	return &MockedAccount{MockedAccountDetails: details}
}
// mockedAccounts indexes mocked accounts by user ID and by username.
type mockedAccounts struct {
	idLookup map[string]*MockedAccount
	lookup   map[string]*MockedAccount
}

// newMockedAccounts builds both indexes from acctProps, skipping empty keys.
func newMockedAccounts(acctProps map[string]*MockedAccountProperties) *mockedAccounts {
	result := &mockedAccounts{
		idLookup: make(map[string]*MockedAccount, len(acctProps)),
		lookup:   make(map[string]*MockedAccount, len(acctProps)),
	}
	for _, props := range acctProps {
		acct := newMockedAccount(props)
		if name := acct.MockedAccountDetails.Username; name != "" {
			result.lookup[name] = acct
		}
		if id := acct.UserId; id != "" {
			result.idLookup[id] = acct
		}
	}
	return result
}

// find looks an account up by username first, then by userId. A non-empty value of the other argument must agree
// with the found account. It returns nil when nothing matches.
func (m mockedAccounts) find(username, userId string) *MockedAccount {
	byName, nameFound := m.lookup[username]
	if nameFound && (len(userId) == 0 || byName.UserId == userId) {
		return byName
	}
	byID, idFound := m.idLookup[userId]
	if idFound && (len(username) == 0 || byID.MockedAccountDetails.Username == username) {
		return byID
	}
	return nil
}

// idToName maps a user ID to its username, falling back to stripping idPrefix for unknown IDs.
func (m mockedAccounts) idToName(id string) string {
	u, ok := m.idLookup[id]
	if !ok {
		return idToExtId(id)
	}
	return u.MockedAccountDetails.Username
}

// nameToId maps a username to its user ID, falling back to prefixing idPrefix for unknown usernames.
func (m mockedAccounts) nameToId(name string) string {
	u, ok := m.lookup[name]
	if !ok {
		return extIdToId(name)
	}
	return u.UserId
}
// mockedTenant is a tenant known to the mocked account stores.
type mockedTenant struct {
	ExternalId string
	ProviderId string
	ID         string
	// Permissions maps a MockedAccountDetails.UserId to that user's permissions within this tenant, which allows
	// per-tenant permissions in mocked setups. See MockAccountStoreWithFinalize for usage.
	Permissions map[string][]string
}

// newMockedTenant builds a mockedTenant from props. When ID is missing it is derived from ExternalId; otherwise,
// when ExternalId is missing, it is derived from ID.
func newMockedTenant(props *MockedTenantProperties) *mockedTenant {
	tenant := &mockedTenant{
		ExternalId: props.ExternalId,
		ID:         props.ID,
	}
	if tenant.ID == "" {
		tenant.ID = extIdToId(tenant.ExternalId)
	} else if tenant.ExternalId == "" {
		tenant.ExternalId = idToExtId(tenant.ID)
	}
	return tenant
}
/*************************
Tenant
*************************/
// mockedTenants indexes mocked tenants by internal ID and by external ID.
type mockedTenants struct {
	idLookup    map[string]*mockedTenant
	extIdLookup map[string]*mockedTenant
}

// newMockedTenants builds both indexes from tenantProps, skipping empty keys.
func newMockedTenants(tenantProps map[string]*MockedTenantProperties) *mockedTenants {
	result := &mockedTenants{
		idLookup:    make(map[string]*mockedTenant, len(tenantProps)),
		extIdLookup: make(map[string]*mockedTenant, len(tenantProps)),
	}
	for _, props := range tenantProps {
		tenant := newMockedTenant(props)
		if extID := tenant.ExternalId; extID != "" {
			result.extIdLookup[extID] = tenant
		}
		if id := tenant.ID; id != "" {
			result.idLookup[id] = tenant
		}
	}
	return result
}

// find looks a tenant up by ID first, then by external ID. A non-empty value of the other argument must agree
// with the found tenant. It returns nil when nothing matches.
func (m mockedTenants) find(tenantId, tenantExternalId string) *mockedTenant {
	byID, found := m.idLookup[tenantId]
	if found && (len(tenantExternalId) == 0 || byID.ExternalId == tenantExternalId) {
		return byID
	}
	byExtID, found := m.extIdLookup[tenantExternalId]
	if found && (len(tenantId) == 0 || byExtID.ID == tenantId) {
		return byExtID
	}
	return nil
}

// idToExtId maps a tenant ID to its external ID, falling back to stripping idPrefix for unknown IDs.
func (m mockedTenants) idToExtId(id string) string {
	t, ok := m.idLookup[id]
	if !ok {
		return idToExtId(id)
	}
	return t.ExternalId
}

// extIdToId maps an external tenant ID to its ID, falling back to prefixing idPrefix for unknown external IDs.
func (m mockedTenants) extIdToId(name string) string {
	t, ok := m.extIdLookup[name]
	if !ok {
		return extIdToId(name)
	}
	return t.ID
}
/*************************
Helpers
*************************/
// idToExtId strips a leading idPrefix from id; IDs without the prefix are returned unchanged.
func idToExtId(id string) string {
	trimmed := strings.TrimPrefix(id, idPrefix)
	return trimmed
}

// extIdToId prepends idPrefix to extId unconditionally.
func extIdToId(extId string) string {
	return idPrefix + extId
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"context"
"errors"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// MockAccountStoreWithFinalize is a MockAccountStore that additionally implements Finalize,
// adjusting the finalized account's tenants and permissions from mocked tenant properties.
type MockAccountStoreWithFinalize struct {
MockAccountStore
// tenantIDLookup indexes mocked tenants by tenant ID; Finalize resolves the selected tenant here.
tenantIDLookup map[string]*mockedTenant
// tenantExtIDLookup indexes mocked tenants by external ID.
tenantExtIDLookup map[string]*mockedTenant
}
// NewMockedAccountStoreWithFinalize creates a MockAccountStoreWithFinalize. Accounts and modifiers
// are handed to NewMockedAccountStore; tenants are indexed by ID and by external ID, skipping
// empty identifiers.
func NewMockedAccountStoreWithFinalize(accountProps []*MockedAccountProperties, tenantProps []*MockedTenantProperties, modifiers ...MockedAccountModifier) *MockAccountStoreWithFinalize {
	accounts := NewMockedAccountStore(accountProps, modifiers...)
	byID := make(map[string]*mockedTenant)
	byExtID := make(map[string]*mockedTenant)
	for i := range tenantProps {
		tenant := newTenant(tenantProps[i])
		if tenant.ExternalId != "" {
			byExtID[tenant.ExternalId] = tenant
		}
		if tenant.ID != "" {
			byID[tenant.ID] = tenant
		}
	}
	return &MockAccountStoreWithFinalize{
		MockAccountStore:  *accounts,
		tenantIDLookup:    byID,
		tenantExtIDLookup: byExtID,
	}
}
// Finalize will read the tenant details from the security.AccountFinalizeOption and
// adjust the user permission depending on which tenant is selected.
// Note that permissions vary depending on the combination of user + tenant.
// User1 with Tenant1 can have different permissions than User2 with Tenant1.
//
// The result is a copy of the stored mock (looked up by username) whose default and assigned
// tenants are taken from the given account. The given account is expected to implement
// security.AccountTenancy and to have a string ID (the type assertions panic otherwise).
// Without a selected tenant, the copy only carries security.SpecialPermissionSwitchTenant.
// With a selected tenant, permissions come from that tenant's mocked permissions for the
// account ID; when there is no entry, the stored mock's permissions are kept.
func (m *MockAccountStoreWithFinalize) Finalize(
	ctx context.Context,
	account security.Account,
	options ...security.AccountFinalizeOptions,
) (security.Account, error) {
	opts := security.AccountFinalizeOption{}
	for _, apply := range options {
		apply(&opts)
	}

	username := account.Username()
	stored, found := m.accountLookupByUsername[username]
	if !found {
		return nil, fmt.Errorf("username: %v not found", username)
	}

	tenancy := account.(security.AccountTenancy)
	finalized := *stored
	finalized.MockedAccountDetails.DefaultTenant = tenancy.DefaultDesignatedTenantId()
	finalized.MockedAccountDetails.AssignedTenants = utils.NewStringSet(tenancy.DesignatedTenantIds()...)

	if opts.Tenant == nil {
		finalized.MockedAccountDetails.Permissions = utils.NewStringSet(security.SpecialPermissionSwitchTenant)
		return finalized, nil
	}

	tenantID := opts.Tenant.Id
	tenant, found := m.tenantIDLookup[tenantID]
	if !found {
		return nil, fmt.Errorf("tenantID: %v not found", tenantID)
	}
	perms, found := tenant.Permissions[account.ID().(string)]
	if found {
		finalized.MockedAccountDetails.Permissions = utils.NewStringSet(perms...)
	}
	return finalized, nil
}
// MockedAccountModifier works with MockAccountStore. It allows tests to modify the mocked account after load.
// LoadAccountById and LoadAccountByUsername apply modifiers in order; if the final result is nil,
// the load fails with a "not found" error.
type MockedAccountModifier func(acct security.Account) security.Account
// MockAccountStore is an in-memory account store backed by mocked account properties.
type MockAccountStore struct {
// accountLookupByUsername indexes mocked accounts by username.
accountLookupByUsername map[string]*MockedAccount
// accountLookupById indexes mocked accounts by user ID; NewMockedAccountStore inserts string keys.
accountLookupById map[interface{}]*MockedAccount
// modifiers are applied, in order, to accounts returned by LoadAccountById and LoadAccountByUsername.
modifiers []MockedAccountModifier
}
// NewMockedAccountStore creates a MockAccountStore from mocked account properties.
// Accounts are indexed by username and by user ID (empty keys are skipped); modifiers are
// applied whenever an account is loaded.
func NewMockedAccountStore(accountProps []*MockedAccountProperties, modifiers ...MockedAccountModifier) *MockAccountStore {
	byName := make(map[string]*MockedAccount)
	byID := make(map[interface{}]*MockedAccount)
	for i := range accountProps {
		account := newMockedAccount(accountProps[i])
		if name := account.Username(); name != "" {
			byName[name] = account
		}
		if account.UserId != "" {
			byID[account.UserId] = account
		}
	}
	return &MockAccountStore{
		accountLookupByUsername: byName,
		accountLookupById:       byID,
		modifiers:               modifiers,
	}
}
// LoadAccountById returns the mocked account with the given ID after passing it through every
// configured modifier. It fails when the ID is unknown or the modifiers end with a nil account.
func (m *MockAccountStore) LoadAccountById(_ context.Context, id interface{}) (security.Account, error) {
	found, ok := m.accountLookupById[id]
	if ok {
		var result security.Account = found
		for i := range m.modifiers {
			result = m.modifiers[i](result)
		}
		if result != nil {
			return result, nil
		}
	}
	return nil, errors.New("user ID not found")
}
// LoadAccountByUsername returns the mocked account with the given username after passing it
// through every configured modifier. It fails when the username is unknown or the modifiers
// end with a nil account.
func (m *MockAccountStore) LoadAccountByUsername(_ context.Context, username string) (security.Account, error) {
	found, ok := m.accountLookupByUsername[username]
	if ok {
		var result security.Account = found
		for i := range m.modifiers {
			result = m.modifiers[i](result)
		}
		if result != nil {
			return result, nil
		}
	}
	return nil, errors.New("username not found")
}
// LoadLockingRules reloads the account by ID. The loaded account is returned when it implements
// security.AccountLockingRule; otherwise a no-op rule named "test-noop" is returned.
func (m *MockAccountStore) LoadLockingRules(ctx context.Context, acct security.Account) (security.AccountLockingRule, error) {
	loaded, err := m.LoadAccountById(ctx, acct.ID())
	if err != nil {
		return nil, err
	}
	rule, isRule := loaded.(security.AccountLockingRule)
	if !isRule {
		rule = &security.DefaultAccount{
			AcctLockingRule: security.AcctLockingRule{Name: "test-noop"},
		}
	}
	return rule, nil
}
// LoadPwdAgingRules reloads the account by ID. The loaded account is returned when it implements
// security.AccountPwdAgingRule; otherwise a no-op password policy named "test-noop" is returned.
func (m *MockAccountStore) LoadPwdAgingRules(ctx context.Context, acct security.Account) (security.AccountPwdAgingRule, error) {
	loaded, err := m.LoadAccountById(ctx, acct.ID())
	if err != nil {
		return nil, err
	}
	rule, isRule := loaded.(security.AccountPwdAgingRule)
	if !isRule {
		rule = &security.DefaultAccount{
			AcctPasswordPolicy: security.AcctPasswordPolicy{Name: "test-noop"},
		}
	}
	return rule, nil
}
// Save is a no-op: the mocked store never persists account changes and always succeeds.
func (m *MockAccountStore) Save(context.Context, security.Account) error { return nil }
// MockedFederatedAccountStore resolves federated (external IdP) accounts from mocked properties.
type MockedFederatedAccountStore struct {
	mocks []*MockedFederatedUserProperties
}

// NewMockedFederatedAccountStore creates a MockedFederatedAccountStore. Without any properties,
// a single wildcard mock is used that matches every IdP, external ID name and external ID value.
func NewMockedFederatedAccountStore(props ...*MockedFederatedUserProperties) MockedFederatedAccountStore {
	if len(props) != 0 {
		return MockedFederatedAccountStore{mocks: props}
	}
	wildcard := &MockedFederatedUserProperties{
		ExtIdpName: "*",
		ExtIdName:  "*",
		ExtIdValue: "*",
	}
	return MockedFederatedAccountStore{mocks: []*MockedFederatedUserProperties{wildcard}}
}
// LoadAccountByExternalId The externalIdName and value matches the test assertion
// The externalIdp matches that from the MockedIdpName
//
// The first mock whose ExtIdpName, ExtIdName and ExtIdValue each equal the requested value (or
// are the wildcard "*") is turned into a federated account. When that mock has no UserId, the
// account ID defaults to "ext-<extIdName>-<extIdValue>". The mock itself is never modified, so
// a wildcard mock produces a distinct ID for each external identity.
func (s MockedFederatedAccountStore) LoadAccountByExternalId(_ context.Context, extIdName string, extIdValue string, extIdpName string, _ security.AutoCreateUserDetails, _ interface{}) (security.Account, error) {
	for _, p := range s.mocks {
		if (p.ExtIdName != "*" && p.ExtIdName != extIdName) ||
			(p.ExtIdValue != "*" && p.ExtIdValue != extIdValue) ||
			(p.ExtIdpName != "*" && p.ExtIdpName != extIdpName) {
			continue
		}
		// Work on a copy: writing the defaulted UserId back into the shared mock would make every
		// later lookup through the same (e.g. wildcard) mock reuse the first caller's ID.
		props := p.MockedAccountProperties
		props.UserId = s.withDefault(props.UserId, fmt.Sprintf("ext-%s-%s", extIdName, extIdValue))
		acct := newMockedAccount(&props)
		acct.MockedAccountDetails.Type = security.AccountTypeFederated
		return acct, nil
	}
	return nil, fmt.Errorf("unable to find federated user by extIdName=%s, extIdValue=%s, extIdpName=%s", extIdName, extIdValue, extIdpName)
}
// withDefault returns val unless it is empty, in which case defaultVal is returned.
func (s MockedFederatedAccountStore) withDefault(val, defaultVal string) string {
	if val != "" {
		return val
	}
	return defaultVal
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/gin-gonic/gin"
)
// ckMockedAuth is the key under which AuthenticationHandlerFunc records the resolved mocked
// authentication (via web.SetKV), so that ForceOverrideHandlerFunc can re-apply it later.
type ckMockedAuth struct{}
// MockAuthenticationMiddleware installs a mocked security.Authentication into each request.
type MockAuthenticationMiddleware struct {
// MWMocker resolves the authentication per request. When it is nil or returns nil,
// MockedAuthentication is used instead.
MWMocker MWMocker
// Deprecated: use MWMocker interface or MWMockFunc.
// Recommended to use WithMockedMiddleware test options
MockedAuthentication security.Authentication
}
// NewMockAuthenticationMiddleware creates a MockAuthenticationMiddleware that yields the given
// authentication for every request.
//
// Deprecated: directly set MWMocker field with MWMocker interface or MWMockFunc, Recommended to use WithMockedMiddleware test options
func NewMockAuthenticationMiddleware(authentication security.Authentication) *MockAuthenticationMiddleware {
	fixed := func(MWMockContext) security.Authentication {
		return authentication
	}
	mw := &MockAuthenticationMiddleware{MockedAuthentication: authentication}
	mw.MWMocker = MWMockFunc(fixed)
	return mw
}
// AuthenticationHandlerFunc returns a gin handler that resolves the mocked authentication for the
// request (MWMocker first, MockedAuthentication as fallback), sets it as the security context and
// records it under ckMockedAuth for ForceOverrideHandlerFunc.
func (m *MockAuthenticationMiddleware) AuthenticationHandlerFunc() gin.HandlerFunc {
	return func(gc *gin.Context) {
		var resolved security.Authentication
		if mocker := m.MWMocker; mocker != nil {
			resolved = mocker.Mock(MWMockContext{Request: gc.Request})
		}
		if resolved == nil {
			resolved = m.MockedAuthentication
		}
		security.MustSet(gc, resolved)
		web.SetKV(gc, ckMockedAuth{}, resolved)
	}
}
// ForceOverrideHandlerFunc returns a gin handler that re-applies the authentication previously
// recorded by AuthenticationHandlerFunc, if any.
func (m *MockAuthenticationMiddleware) ForceOverrideHandlerFunc() gin.HandlerFunc {
	return func(gc *gin.Context) {
		if recorded, ok := gc.Value(ckMockedAuth{}).(security.Authentication); ok {
			security.MustSet(gc, recorded)
		}
	}
}
// MockUserAuthOptions mutates a MockUserAuthOption; it is consumed by NewMockedUserAuthentication.
type MockUserAuthOptions func(opt *MockUserAuthOption)
// MockUserAuthOption collects the values NewMockedUserAuthentication copies into the mocked authentication.
type MockUserAuthOption struct {
// Principal becomes the authentication's principal.
Principal string
// Permissions becomes the authentication's permission map.
Permissions map[string]interface{}
// State becomes the authentication's state.
State security.AuthenticationState
// Details becomes the authentication's details.
Details interface{}
}
// mockUserAuthentication is a minimal user authentication whose values are plain fields.
type mockUserAuthentication struct {
	Subject       string
	PermissionMap map[string]interface{}
	StateValue    security.AuthenticationState
	details       interface{}
}

// NewMockedUserAuthentication applies the given options in order and builds a mocked user
// authentication from the resulting values.
func NewMockedUserAuthentication(opts ...MockUserAuthOptions) *mockUserAuthentication {
	var cfg MockUserAuthOption
	for i := range opts {
		opts[i](&cfg)
	}
	auth := new(mockUserAuthentication)
	auth.Subject = cfg.Principal
	auth.PermissionMap = cfg.Permissions
	auth.StateValue = cfg.State
	auth.details = cfg.Details
	return auth
}

// Principal returns the subject (a string).
func (a *mockUserAuthentication) Principal() interface{} { return a.Subject }

// Permissions returns the configured permission map.
func (a *mockUserAuthentication) Permissions() security.Permissions { return a.PermissionMap }

// State returns the configured authentication state.
func (a *mockUserAuthentication) State() security.AuthenticationState { return a.StateValue }

// Details returns the configured details value.
func (a *mockUserAuthentication) Details() interface{} { return a.details }
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
var (
	// defaultClientGrantTypes is used by MockedClient.GrantTypes when no grant types are configured.
	defaultClientGrantTypes = utils.NewStringSet(
		oauth2.GrantTypeClientCredentials, oauth2.GrantTypePassword,
		oauth2.GrantTypeAuthCode, oauth2.GrantTypeImplicit,
		oauth2.GrantTypeRefresh, oauth2.GrantTypeSwitchUser,
		oauth2.GrantTypeSwitchTenant, oauth2.GrantTypeSamlSSO,
	)
	// defaultClientScopes is used by MockedClient.Scopes when no scopes are configured.
	defaultClientScopes = utils.NewStringSet(
		oauth2.ScopeRead,
		oauth2.ScopeWrite,
		oauth2.ScopeTokenDetails,
		oauth2.ScopeTenantHierarchy,
		oauth2.ScopeOidc,
		oauth2.ScopeOidcProfile,
		oauth2.ScopeOidcEmail,
		oauth2.ScopeOidcAddress,
		oauth2.ScopeOidcPhone,
	)
)
// MockedClient is a mocked OAuth2 client backed by MockedClientProperties. It also acts as the
// client's security.Account.
type MockedClient struct {
	MockedClientProperties
}

// ID returns the client ID.
func (c MockedClient) ID() interface{} { return c.MockedClientProperties.ClientID }

// Type always reports security.AccountTypeDefault.
func (c MockedClient) Type() security.AccountType { return security.AccountTypeDefault }

// Username returns the client ID.
func (c MockedClient) Username() string { return c.MockedClientProperties.ClientID }

// Credentials returns the client secret.
func (c MockedClient) Credentials() interface{} { return c.MockedClientProperties.Secret }

// Permissions returns nil: mocked clients carry no account permissions.
func (c MockedClient) Permissions() []string { return nil }

// Disabled always returns false.
func (c MockedClient) Disabled() bool { return false }

// Locked always returns false.
func (c MockedClient) Locked() bool { return false }

// UseMFA always returns false.
func (c MockedClient) UseMFA() bool { return false }

// CacheableCopy returns a copy of the client with its secret cleared.
func (c MockedClient) CacheableCopy() security.Account {
	props := c.MockedClientProperties
	props.Secret = ""
	return MockedClient{MockedClientProperties: props}
}

// ClientId returns the client ID.
func (c MockedClient) ClientId() string { return c.MockedClientProperties.ClientID }

// SecretRequired reports whether a non-empty secret is configured.
func (c MockedClient) SecretRequired() bool { return c.MockedClientProperties.Secret != "" }

// Secret returns the configured secret.
func (c MockedClient) Secret() string { return c.MockedClientProperties.Secret }

// GrantTypes returns the configured grant types, or defaultClientGrantTypes when none are configured.
func (c MockedClient) GrantTypes() utils.StringSet {
	if grants := c.MockedClientProperties.GrantTypes; grants != nil {
		return utils.NewStringSet(grants...)
	}
	return defaultClientGrantTypes
}

// RedirectUris returns the configured redirect URIs.
func (c MockedClient) RedirectUris() utils.StringSet {
	return utils.NewStringSet(c.MockedClientProperties.RedirectUris...)
}

// Scopes returns the configured scopes, or defaultClientScopes when none are configured.
func (c MockedClient) Scopes() utils.StringSet {
	if scopes := c.MockedClientProperties.Scopes; scopes != nil {
		return utils.NewStringSet(scopes...)
	}
	return defaultClientScopes
}

// AutoApproveScopes returns the configured auto-approve scopes, falling back to Scopes().
func (c MockedClient) AutoApproveScopes() utils.StringSet {
	if approved := c.MockedClientProperties.AutoApproveScopes; approved != nil {
		return utils.NewStringSet(approved...)
	}
	return c.Scopes()
}

// AccessTokenValidity returns the configured access token validity.
func (c MockedClient) AccessTokenValidity() time.Duration {
	return time.Duration(c.MockedClientProperties.ATValidity)
}

// RefreshTokenValidity returns the configured refresh token validity.
func (c MockedClient) RefreshTokenValidity() time.Duration {
	return time.Duration(c.MockedClientProperties.RTValidity)
}

// UseSessionTimeout always returns true.
func (c MockedClient) UseSessionTimeout() bool { return true }

// AssignedTenantIds returns the configured tenant IDs.
func (c MockedClient) AssignedTenantIds() utils.StringSet {
	return utils.NewStringSet(c.MockedClientProperties.AssignedTenantIds...)
}

// ResourceIDs always returns an empty set.
func (c MockedClient) ResourceIDs() utils.StringSet { return utils.NewStringSet() }
// MockedClientStore serves MockedClient instances looked up by client ID.
type MockedClientStore struct {
	idLookup map[string]*MockedClient
}

// NewMockedClientStore indexes the given client properties by client ID; later entries with the
// same ID replace earlier ones.
func NewMockedClientStore(props ...*MockedClientProperties) *MockedClientStore {
	lookup := make(map[string]*MockedClient, len(props))
	for _, p := range props {
		lookup[p.ClientID] = &MockedClient{MockedClientProperties: *p}
	}
	return &MockedClientStore{idLookup: lookup}
}

// LoadClientByClientId returns the mocked client with the given ID, or an error if it is unknown.
func (s *MockedClientStore) LoadClientByClientId(_ context.Context, clientId string) (oauth2.OAuth2Client, error) {
	client, found := s.idLookup[clientId]
	if !found {
		return nil, fmt.Errorf("cannot find client with client ID [%s]", clientId)
	}
	return client, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/utils"
"time"
)
// SecurityMockOptions mutates a SecurityDetailsMock; it is applied by NewMockedSecurityDetails.
type SecurityMockOptions func(d *SecurityDetailsMock)
// SecurityDetailsMock holds the raw values exposed by MockedSecurityDetails.
// NewMockedSecurityDetails defaults AccountType to security.AccountTypeDefault and ClientID to "mock".
type SecurityDetailsMock struct {
Username string
UserId string
AccountType security.AccountType
TenantExternalId string
TenantId string
ProviderName string
ProviderId string
ProviderDisplayName string
ProviderDescription string
ProviderEmail string
ProviderNotificationType string
AccessToken string
// Exp is returned by ExpiryTime.
Exp time.Time
// Iss is returned by both IssueTime and AuthenticationTime.
Iss time.Time
Permissions utils.StringSet
Roles utils.StringSet
// Tenants is returned by EffectiveAssignedTenantIds (lazily initialized to an empty set).
Tenants utils.StringSet
// OrigUsername is returned by OriginalUsername; a non-empty value makes Proxied report true.
OrigUsername string
UserFirstName string
UserLastName string
// KVs backs Value/Values; the "LocaleCode", "CurrencyCode" and "Email" entries are read as strings.
KVs map[string]interface{}
ClientID string
Scopes utils.StringSet
OAuth2GrantType string
OAuth2ResponseTypes utils.StringSet
OAuth2Parameters map[string]string
OAuth2Extensions map[string]interface{}
}
// MockedSecurityDetails implements
//   - security.AuthenticationDetails
//   - security.ProxiedUserDetails
//   - security.UserDetails
//   - security.TenantDetails
//   - security.ProviderDetails
//   - security.KeyValueDetails
//   - internal.TenantAccessDetails
//
// Every value is read from the embedded SecurityDetailsMock.
type MockedSecurityDetails struct {
	SecurityDetailsMock
}

// NewMockedSecurityDetails creates details with AccountType security.AccountTypeDefault and
// ClientID "mock", then applies opts in order.
func NewMockedSecurityDetails(opts ...SecurityMockOptions) *MockedSecurityDetails {
	details := &MockedSecurityDetails{
		SecurityDetailsMock: SecurityDetailsMock{
			AccountType: security.AccountTypeDefault,
			ClientID:    "mock",
		},
	}
	for _, apply := range opts {
		apply(&details.SecurityDetailsMock)
	}
	return details
}

// Value looks up key in KVs.
func (d *MockedSecurityDetails) Value(key string) (interface{}, bool) {
	val, found := d.KVs[key]
	return val, found
}

// Values returns the KVs map itself (not a copy).
func (d *MockedSecurityDetails) Values() map[string]interface{} { return d.KVs }

// OriginalUsername returns OrigUsername.
func (d *MockedSecurityDetails) OriginalUsername() string { return d.OrigUsername }

// Proxied reports whether OrigUsername is set.
func (d *MockedSecurityDetails) Proxied() bool { return len(d.OrigUsername) != 0 }

// ExpiryTime returns Exp.
func (d *MockedSecurityDetails) ExpiryTime() time.Time { return d.Exp }

// IssueTime returns Iss.
func (d *MockedSecurityDetails) IssueTime() time.Time { return d.Iss }

// Roles returns the configured roles (may be nil).
func (d *MockedSecurityDetails) Roles() utils.StringSet { return d.SecurityDetailsMock.Roles }

// Permissions returns the configured permissions, initializing them to an empty set on first use.
func (d *MockedSecurityDetails) Permissions() utils.StringSet {
	mock := &d.SecurityDetailsMock
	if mock.Permissions == nil {
		mock.Permissions = utils.NewStringSet()
	}
	return mock.Permissions
}

// AuthenticationTime returns Iss, same as IssueTime.
func (d *MockedSecurityDetails) AuthenticationTime() time.Time { return d.Iss }

// ProviderId returns the configured provider ID.
func (d *MockedSecurityDetails) ProviderId() string { return d.SecurityDetailsMock.ProviderId }

// ProviderName returns the configured provider name.
func (d *MockedSecurityDetails) ProviderName() string { return d.SecurityDetailsMock.ProviderName }

// ProviderDisplayName returns the configured provider display name.
func (d *MockedSecurityDetails) ProviderDisplayName() string {
	return d.SecurityDetailsMock.ProviderDisplayName
}

// ProviderDescription returns the configured provider description.
func (d *MockedSecurityDetails) ProviderDescription() string {
	return d.SecurityDetailsMock.ProviderDescription
}

// ProviderEmail returns the configured provider email.
func (d *MockedSecurityDetails) ProviderEmail() string { return d.SecurityDetailsMock.ProviderEmail }

// ProviderNotificationType returns the configured provider notification type.
func (d *MockedSecurityDetails) ProviderNotificationType() string {
	return d.SecurityDetailsMock.ProviderNotificationType
}

// TenantId returns the configured tenant ID.
func (d *MockedSecurityDetails) TenantId() string { return d.SecurityDetailsMock.TenantId }

// TenantExternalId returns the configured tenant external ID.
func (d *MockedSecurityDetails) TenantExternalId() string {
	return d.SecurityDetailsMock.TenantExternalId
}

// TenantSuspended always returns false.
func (d *MockedSecurityDetails) TenantSuspended() bool { return false }

// UserId returns the configured user ID.
func (d *MockedSecurityDetails) UserId() string { return d.SecurityDetailsMock.UserId }

// Username returns the configured username.
func (d *MockedSecurityDetails) Username() string { return d.SecurityDetailsMock.Username }

// AccountType returns the configured account type.
func (d *MockedSecurityDetails) AccountType() security.AccountType {
	return d.SecurityDetailsMock.AccountType
}

// Deprecated: the interface is deprecated
func (d *MockedSecurityDetails) AssignedTenantIds() utils.StringSet {
	return d.EffectiveAssignedTenantIds()
}

// EffectiveAssignedTenantIds returns Tenants, initializing it to an empty set on first use.
func (d *MockedSecurityDetails) EffectiveAssignedTenantIds() utils.StringSet {
	if d.Tenants != nil {
		return d.Tenants
	}
	d.Tenants = utils.NewStringSet()
	return d.Tenants
}

// LocaleCode returns KVs["LocaleCode"] when it is a string, "" otherwise.
func (d *MockedSecurityDetails) LocaleCode() string { return valueFromMap[string](d.KVs, "LocaleCode") }

// CurrencyCode returns KVs["CurrencyCode"] when it is a string, "" otherwise.
func (d *MockedSecurityDetails) CurrencyCode() string {
	return valueFromMap[string](d.KVs, "CurrencyCode")
}

// FirstName returns UserFirstName.
func (d *MockedSecurityDetails) FirstName() string { return d.UserFirstName }

// LastName returns UserLastName.
func (d *MockedSecurityDetails) LastName() string { return d.UserLastName }

// Email returns KVs["Email"] when it is a string, "" otherwise.
func (d *MockedSecurityDetails) Email() string { return valueFromMap[string](d.KVs, "Email") }
// valueFromMap returns m[key] when it holds a value of type T, and T's zero value otherwise
// (missing key, mismatched type, or nil map).
func valueFromMap[T any](m map[string]interface{}, key string) T {
	// Indexing a nil map is safe and yields nil, which never satisfies the assertion below.
	v, _ := m[key].(T)
	return v
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/idp"
)
// MockedPasswdIdentityProvider is a mocked identity provider that uses the internal
// form-login flow (idp.InternalIdpForm) for a single fixed domain.
type MockedPasswdIdentityProvider struct {
	domain string
}

// NewMockedPasswdIdentityProvider creates a provider serving the given domain.
func NewMockedPasswdIdentityProvider(domain string) *MockedPasswdIdentityProvider {
	p := new(MockedPasswdIdentityProvider)
	p.domain = domain
	return p
}

// AuthenticationFlow always reports idp.InternalIdpForm.
func (s MockedPasswdIdentityProvider) AuthenticationFlow() idp.AuthenticationFlow {
	return idp.InternalIdpForm
}

// Domain returns the configured domain.
func (s MockedPasswdIdentityProvider) Domain() string { return s.domain }
// MockedIDPManager exposes a single mocked password (form-login) identity provider.
type MockedIDPManager struct {
	idpPasswd idp.IdentityProvider
}

// IdpManagerMockOptions mutates an IdpManagerMockOption.
type IdpManagerMockOptions func(opt *IdpManagerMockOption)

// IdpManagerMockOption configures NewMockedIDPManager.
type IdpManagerMockOption struct {
	// PasswdIDPDomain is the domain served by the mocked password identity provider.
	PasswdIDPDomain string
}

// NewMockedIDPManager creates a MockedIDPManager after applying opts in order.
func NewMockedIDPManager(opts ...IdpManagerMockOptions) *MockedIDPManager {
	var cfg IdpManagerMockOption
	for i := range opts {
		opts[i](&cfg)
	}
	return &MockedIDPManager{idpPasswd: NewMockedPasswdIdentityProvider(cfg.PasswdIDPDomain)}
}

// GetIdentityProvidersWithFlow returns the password provider for idp.InternalIdpForm and an
// empty (non-nil) slice for any other flow.
func (m *MockedIDPManager) GetIdentityProvidersWithFlow(ctx context.Context, flow idp.AuthenticationFlow) []idp.IdentityProvider {
	if flow != idp.InternalIdpForm {
		return []idp.IdentityProvider{}
	}
	return []idp.IdentityProvider{m.idpPasswd}
}

// GetIdentityProviderByDomain returns the password provider when domain matches its domain.
func (m *MockedIDPManager) GetIdentityProviderByDomain(ctx context.Context, domain string) (idp.IdentityProvider, error) {
	if domain == m.idpPasswd.Domain() {
		return m.idpPasswd, nil
	}
	return nil, fmt.Errorf("cannot find IDP for domain [%s]", domain)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"encoding/json"
"fmt"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/pkg/errors"
"strings"
"time"
)
const (
// MockingPropertiesPrefix is the root configuration section for all mocking properties.
MockingPropertiesPrefix = "mocking"
)
// MockingProperties aggregates all mocked properties bound from the MockingPropertiesPrefix section.
type MockingProperties struct {
Accounts MockedPropertiesAccounts `json:"accounts"`
Tenants MockedPropertiesTenants `json:"tenants"`
Clients MockedPropertiesClients `json:"clients"`
FederatedUsers MockedPropertiesFederatedUsers `json:"fed-users"`
}
// BindMockingProperties is an FX provider that binds all mocked properties as MockingProperties.
// All mocked properties should be under the yaml section defined as MockingPropertiesPrefix,
// e.g. "mocking.accounts" defines all account mocks.
func BindMockingProperties(ctx *bootstrap.ApplicationContext) (MockingProperties, error) {
	bind := MockedPropertiesBinder[MockingProperties]("")
	return bind(ctx)
}
// MockedProperties holds mocked property entries of type T, keyed by their names in the
// configuration section they were bound from.
type MockedProperties[T any] map[string]*T

// Values returns all entries in unspecified (map iteration) order. The result is never nil.
func (p MockedProperties[T]) Values() []*T {
	out := make([]*T, len(p))
	i := 0
	for _, v := range p {
		out[i] = v
		i++
	}
	return out
}

// MapValues returns the entries keyed by name; the returned map is p itself, not a copy.
func (p MockedProperties[T]) MapValues() map[string]*T {
	return map[string]*T(p)
}

// UnmarshalJSON decodes a JSON object into p as a plain map of name to entry.
func (p *MockedProperties[T]) UnmarshalJSON(data []byte) error {
	// Convert to the unnamed map type so json.Unmarshal does not recurse into this method.
	raw := (*map[string]*T)(p)
	return json.Unmarshal(data, raw)
}
// MockedPropertiesBinder returns an FX provider that binds a specific mocked properties type from
// the properties sub-section specified by "prefix". The root section prefix is defined by
// MockingPropertiesPrefix, e.g. MockedPropertiesBinder[MockedPropertiesAccounts]("accounts"):
//
// The returned binder binds MockedPropertiesAccounts from "mocking.accounts".
//
// The full key is resolved once when the binder is created, so the returned function can be
// invoked any number of times (e.g. by several test applications) and always binds the same section.
func MockedPropertiesBinder[T any](prefix string) func(ctx *bootstrap.ApplicationContext) (T, error) {
	// Resolve outside the closure: reassigning the captured `prefix` inside it would prepend
	// MockingPropertiesPrefix again on every call ("mocking.mocking.accounts", ...).
	key := strings.Trim(MockingPropertiesPrefix+"."+prefix, ".")
	return func(ctx *bootstrap.ApplicationContext) (T, error) {
		var props T
		if err := ctx.Config().Bind(&props, key); err != nil {
			return props, errors.Wrap(err, fmt.Sprintf("failed to bind mocking properties %T from [%s]", props, key))
		}
		return props, nil
	}
}
// MockedPropertiesClients holds mocked OAuth2 client entries.
type MockedPropertiesClients struct {
MockedProperties[MockedClientProperties]
}
// MockedClientProperties describes a mocked OAuth2 client. Nil GrantTypes, Scopes or
// AutoApproveScopes make MockedClient fall back to its defaults.
type MockedClientProperties struct {
ClientID string `json:"id"`
Secret string `json:"secret"`
GrantTypes utils.CommaSeparatedSlice `json:"grant-types"`
Scopes utils.CommaSeparatedSlice `json:"scopes"`
AutoApproveScopes utils.CommaSeparatedSlice `json:"auto-approve-scopes"`
RedirectUris utils.CommaSeparatedSlice `json:"redirect-uris"`
ATValidity utils.Duration `json:"access-token-validity"`
RTValidity utils.Duration `json:"refresh-token-validity"`
AssignedTenantIds utils.CommaSeparatedSlice `json:"tenants"`
}
// MockedPropertiesAccounts holds mocked account entries.
type MockedPropertiesAccounts struct {
MockedProperties[MockedAccountProperties]
}
// MockedAccountProperties describes a mocked user account.
type MockedAccountProperties struct {
UserId string `json:"id"` // optional field
Username string `json:"username"`
Password string `json:"password"`
DefaultTenant string `json:"default-tenant"`
Tenants []string `json:"tenants"`
Perms []string `json:"permissions"`
}
// MockedPropertiesFederatedUsers holds mocked federated user entries.
type MockedPropertiesFederatedUsers struct {
MockedProperties[MockedFederatedUserProperties]
}
// MockedFederatedUserProperties describes a mocked federated user. Each Ext* field must equal the
// requested value or be the wildcard "*" for MockedFederatedAccountStore to match it.
type MockedFederatedUserProperties struct {
MockedAccountProperties
ExtIdpName string `json:"ext-idp-name"`
ExtIdName string `json:"ext-id-name"`
ExtIdValue string `json:"ext-id-value"`
}
// MockedPropertiesTenants holds mocked tenant entries.
type MockedPropertiesTenants struct {
MockedProperties[MockedTenantProperties]
}
// MockedTenantProperties describes a mocked tenant.
type MockedTenantProperties struct {
ID string `json:"id"` // optional field
ExternalId string `json:"external-id"`
Perms map[string][]string `json:"permissions"` // permissions are MockedAccountProperties.UserId to permissions
}
// scopeMockingProperties extends MockingProperties with the validity of tokens issued by mocks.
type scopeMockingProperties struct {
	MockingProperties
	TokenValidity utils.Duration `json:"token-validity"`
}

// bindScopeMockingProperties binds the whole MockingPropertiesPrefix section, with TokenValidity
// defaulting to 120 seconds. It panics when binding fails.
func bindScopeMockingProperties(ctx *bootstrap.ApplicationContext) *scopeMockingProperties {
	props := new(scopeMockingProperties)
	props.TokenValidity = utils.Duration(2 * time.Minute)
	if err := ctx.Config().Bind(props, MockingPropertiesPrefix); err != nil {
		panic(errors.Wrap(err, "failed to bind mocking properties"))
	}
	return props
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
)
const (
	// MockedProviderID is the only provider ID MockedProviderStore knows about.
	MockedProviderID = "test-provider"
	// MockedProviderName is used as the mocked provider's name, display name and description.
	MockedProviderName = "test-provider"
)

// MockedProviderStore serves a single mocked provider identified by MockedProviderID.
type MockedProviderStore struct{}

// LoadProviderById returns the mocked provider for MockedProviderID and an error for any other ID.
func (s MockedProviderStore) LoadProviderById(_ context.Context, id string) (*security.Provider, error) {
	if id == MockedProviderID {
		return &security.Provider{
			Id:               id,
			Name:             MockedProviderName,
			DisplayName:      MockedProviderName,
			Description:      MockedProviderName,
			LocaleCode:       "en_US",
			NotificationType: "EMAIL",
			Email:            "admin@cisco.com",
		}, nil
	}
	return nil, fmt.Errorf("cannot find provider with id [%s]", id)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/integrate/security/seclient"
"github.com/cisco-open/go-lanai/pkg/security"
"time"
)
/*************************
Mocks
*************************/
// mockedAuthClient is a mocked seclient.AuthenticationClient that issues mocked tokens for the
// mocked accounts and tenants reachable through the embedded mockedTokenBase.
type mockedAuthClient struct {
*mockedTokenBase
// tokenExp is the validity of issued tokens.
tokenExp time.Duration
}
// ClientCredentials mocked function will accept any clientID as long as it is accompanied
// by a client Secret.
// If the tokenExp is not defined, the token is valid for 3600 seconds (one hour).
//
// The returned token carries the requested scopes, is issued now, and expires after tokenExp.
func (c *mockedAuthClient) ClientCredentials(ctx context.Context, opts ...seclient.AuthOptions) (*seclient.Result, error) {
	opt, err := c.option(opts)
	if err != nil {
		return nil, err
	}
	// Both the client ID and its secret are mandatory for the client_credentials grant.
	// (Previously inverted: valid credentials were rejected and missing ones accepted.)
	if opt.ClientID == "" || opt.ClientSecret == "" {
		return nil, fmt.Errorf("clientID and clientSecret need to be defined")
	}
	tokenExp := c.tokenExp
	if tokenExp == 0 {
		// time.Duration counts nanoseconds; a bare 3600 would expire after 3.6µs.
		tokenExp = 3600 * time.Second
	}
	now := time.Now()
	exp := now.UTC().Add(tokenExp)
	return &seclient.Result{
		Token: &MockedToken{
			MockedTokenInfo: MockedTokenInfo{
				Scopes: opt.Scopes,
			},
			ExpTime: exp,
			IssTime: now,
		},
	}, nil
}
// newMockedAuthClient creates a mocked authentication client on top of the given token base;
// tokenValidity is the lifetime of the tokens it issues.
func newMockedAuthClient(base *mockedTokenBase, tokenValidity time.Duration) seclient.AuthenticationClient {
	client := &mockedAuthClient{mockedTokenBase: base}
	client.tokenExp = tokenValidity
	return client
}
// PasswordLogin checks the username/password against the mocked accounts and issues a token
// for the requested tenant, or the account's default tenant. Supplying an access token is an error.
func (c *mockedAuthClient) PasswordLogin(_ context.Context, opts ...seclient.AuthOptions) (*seclient.Result, error) {
	opt, err := c.option(opts)
	if err != nil {
		return nil, err
	}
	if opt.AccessToken != "" {
		return nil, fmt.Errorf("[Mocked Error] access token is not allowed for password login")
	}
	account := c.accounts.find(opt.Username, "")
	if account == nil || account.Password != opt.Password {
		return nil, fmt.Errorf("[Mocked Error] username and password don't match")
	}
	tenant, err := c.resolveTenant(opt, account)
	if err != nil {
		return nil, err
	}
	expiry := time.Now().UTC().Add(c.tokenExp)
	token := c.newMockedToken(account, tenant, expiry, "")
	return &seclient.Result{Token: token}, nil
}
// SwitchUser issues a token for the target user (by username or user ID) on behalf of the user
// identified by the given access token. That user must hold security.SpecialPermissionSwitchUser.
// The issued token records the original username.
func (c *mockedAuthClient) SwitchUser(_ context.Context, opts ...seclient.AuthOptions) (*seclient.Result, error) {
	opt, err := c.option(opts)
	if err != nil {
		return nil, err
	}
	current, err := c.parseMockedToken(opt.AccessToken)
	if err != nil || current.UName == "" {
		return nil, fmt.Errorf("[Mocked Error] invalid access token")
	}
	operator := c.accounts.find(current.UName, current.UID)
	if operator == nil || !operator.MockedAccountDetails.Permissions.Has(security.SpecialPermissionSwitchUser) {
		return nil, fmt.Errorf("[Mocked Error] switch user not allowed")
	}
	target := c.accounts.find(opt.Username, opt.UserId)
	if target == nil {
		return nil, fmt.Errorf("[Mocked Error] target user doesn't exists")
	}
	tenant, err := c.resolveTenant(opt, target)
	if err != nil {
		return nil, err
	}
	expiry := time.Now().UTC().Add(c.tokenExp)
	token := c.newMockedToken(target, tenant, expiry, current.UName)
	return &seclient.Result{Token: token}, nil
}
// SwitchTenant issues a new token for the user identified by the given access token, scoped to
// the requested tenant. Username/userId must not be supplied, and the user must hold
// security.SpecialPermissionSwitchTenant.
func (c *mockedAuthClient) SwitchTenant(_ context.Context, opts ...seclient.AuthOptions) (*seclient.Result, error) {
	opt, err := c.option(opts)
	if err != nil {
		return nil, err
	}
	if opt.Username != "" || opt.UserId != "" {
		return nil, fmt.Errorf("[Mocked Error] username or userId not allowed in switching tenant")
	}
	current, err := c.parseMockedToken(opt.AccessToken)
	if err != nil || current.UName == "" {
		return nil, fmt.Errorf("[Mocked Error] invalid access token")
	}
	account := c.accounts.find(current.UName, current.UID)
	if account == nil || !account.MockedAccountDetails.Permissions.Has(security.SpecialPermissionSwitchTenant) {
		return nil, fmt.Errorf("[Mocked Error] switch tenant not allowed or deleted user")
	}
	tenant, err := c.resolveTenant(opt, account)
	if err != nil {
		return nil, err
	}
	expiry := time.Now().UTC().Add(c.tokenExp)
	token := c.newMockedToken(account, tenant, expiry, "")
	return &seclient.Result{Token: token}, nil
}
// option applies opts to a fresh seclient.AuthOption and rejects combinations where mutually
// exclusive identifiers are both set: username vs. userId, and tenantId vs. tenantExternalId.
func (c *mockedAuthClient) option(opts []seclient.AuthOptions) (*seclient.AuthOption, error) {
	opt := seclient.AuthOption{}
	for _, fn := range opts {
		fn(&opt)
	}
	if opt.UserId != "" && opt.Username != "" {
		return nil, fmt.Errorf("[Mocked Error] username and userId are exclusive")
	}
	if opt.TenantId != "" && opt.TenantExternalId != "" {
		// Report the actual conflict; this used to reuse the username/userId message.
		return nil, fmt.Errorf("[Mocked Error] tenantId and tenantExternalId are exclusive")
	}
	return &opt, nil
}
// resolveTenant picks the tenant for a mocked login:
//   - the explicitly requested tenant (by ID and/or external ID), if any;
//   - otherwise the account's default tenant.
//
// It fails with a specific message when no tenant is specified and no default is configured,
// when the requested or default tenant is unknown, or when the account is assigned neither to
// that tenant nor to the wildcard tenant (security.SpecialTenantIdWildcard).
func (c *mockedAuthClient) resolveTenant(opt *seclient.AuthOption, acct *MockedAccount) (ret *mockedTenant, err error) {
	switch {
	case opt.TenantId != "" || opt.TenantExternalId != "":
		if ret = c.tenants.find(opt.TenantId, opt.TenantExternalId); ret == nil {
			return nil, fmt.Errorf("[Mocked Error] tenant [id=%s, externalId=%s] not found", opt.TenantId, opt.TenantExternalId)
		}
	case acct.DefaultTenant != "":
		if ret = c.tenants.find(acct.DefaultTenant, ""); ret == nil {
			return nil, fmt.Errorf("[Mocked Error] default tenant [%s] not found", acct.DefaultTenant)
		}
	default:
		return nil, fmt.Errorf("[Mocked Error] tenant not specified and default tenant not configured")
	}
	if !acct.AssignedTenants.Has(ret.ID) && !acct.AssignedTenants.Has(security.SpecialTenantIdWildcard) {
		return nil, fmt.Errorf("[Mocked Error] user does not have access to tenant [%s]", ret.ID)
	}
	return ret, nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/session"
sessioncommon "github.com/cisco-open/go-lanai/pkg/security/session/common"
"github.com/cisco-open/go-lanai/test/webtest"
"github.com/google/uuid"
"net/http"
)
/*******************
Options
*******************/
// SessionID returns a webtest.RequestOptions that sends sessionId as the session cookie
// (named sessioncommon.DefaultName). It overwrites any existing "Cookie" header on the request.
func SessionID(sessionId string) webtest.RequestOptions {
	return func(req *http.Request) {
		sessionCookie := &http.Cookie{Value: sessionId, Name: sessioncommon.DefaultName}
		req.Header.Set("Cookie", sessionCookie.String())
	}
}
/*******************
Mocks
*******************/
// sessionKeyPrincipal is the session attribute key under which MockedSessionStore records a session's principal.
// NOTE(review): an unnamed struct{} value compares equal to any other struct{}{} key stored in the same session;
// consider a private named key type if other code might use struct{}{} as a key.
var sessionKeyPrincipal struct{}

// MockedSessionStore is an in-memory session.Store for tests.
// Sessions maps a key derived from each session's ID and name to the session itself.
type MockedSessionStore struct {
	Sessions map[string]*session.Session
}
// NewMockedSessionStore returns an empty MockedSessionStore as a session.Store.
func NewMockedSessionStore() session.Store {
	store := &MockedSessionStore{}
	store.Sessions = make(map[string]*session.Session)
	return store
}

// MockedSessionStoreDecorator ignores the given session.Store and returns a fresh MockedSessionStore in its place.
func MockedSessionStoreDecorator(_ session.Store) session.Store {
	mocked := NewMockedSessionStore()
	return mocked
}
// Get returns the stored session with the given id and name.
// An empty id, or an id with no stored session, yields a newly created (and stored) session instead.
// The returned error is always nil.
func (ss *MockedSessionStore) Get(id string, name string) (s *session.Session, err error) {
	if id == "" {
		return ss.New(name)
	}
	if existing := ss.Sessions[ss.toKey(id, name)]; existing != nil {
		return existing, nil
	}
	s, _ = ss.New(name)
	return s, nil
}
// New creates a session via session.CreateSession and stores it when the result is non-nil.
// The returned error is always nil.
func (ss *MockedSessionStore) New(name string) (*session.Session, error) {
	created := session.CreateSession(ss, name)
	if created == nil {
		return nil, nil
	}
	ss.Sessions[ss.key(created)] = created
	return created, nil
}
// Save stores s under the key derived from its ID and name. A nil session is ignored.
func (ss *MockedSessionStore) Save(s *session.Session) error {
	if s == nil {
		return nil
	}
	ss.Sessions[ss.key(s)] = s
	return nil
}
// Invalidate removes the given sessions from the store. Nil sessions are skipped, matching New and Save.
// Besides the entry under the session's own key, any other entry pointing to the same session is removed too:
// ChangeId re-keys a session under a generated ID while s.GetID() keeps its old value, so that entry could
// otherwise never be found and removed.
func (ss *MockedSessionStore) Invalidate(sessions ...*session.Session) error {
	for _, s := range sessions {
		if s == nil {
			continue
		}
		delete(ss.Sessions, ss.key(s))
		for k, v := range ss.Sessions {
			if v == s {
				delete(ss.Sessions, k)
			}
		}
	}
	return nil
}
// Options returns fixed cookie options for the mocked store: Path "/" and Domain "localhost".
func (ss *MockedSessionStore) Options() *session.Options {
	opts := session.Options{Domain: "localhost", Path: "/"}
	return &opts
}
// ChangeId simulates a session ID change for a non-nil session.
// The session's ID is a private field of session.Session and cannot be changed from here, so the session is only
// moved from its current key to a key built with toKey from a freshly generated UUID and the session's name.
// This keeps every map key in the same format. Note: s.GetID() still returns the old ID afterwards.
func (ss *MockedSessionStore) ChangeId(s *session.Session) error {
	if s == nil {
		return nil
	}
	newId := uuid.New().String()
	delete(ss.Sessions, ss.key(s))
	ss.Sessions[ss.toKey(newId, s.Name())] = s
	return nil
}
// AddToPrincipalIndex records principal as an attribute on s; the mocked store keeps no separate index.
// A nil session is ignored. The returned error is always nil.
func (ss *MockedSessionStore) AddToPrincipalIndex(principal string, s *session.Session) error {
	if s == nil {
		return nil
	}
	s.Set(sessionKeyPrincipal, principal)
	return nil
}

// RemoveFromPrincipalIndex deletes the principal attribute from s; the principal argument is not used.
// A nil session is ignored. The returned error is always nil.
func (ss *MockedSessionStore) RemoveFromPrincipalIndex(_ string, s *session.Session) error {
	if s == nil {
		return nil
	}
	s.Delete(sessionKeyPrincipal)
	return nil
}
// FindByPrincipalName scans every stored session and returns those whose recorded principal equals principal
// and whose name equals sessionName. The result is nil when nothing matches, and its order follows map iteration.
// The returned error is always nil.
func (ss *MockedSessionStore) FindByPrincipalName(principal string, sessionName string) ([]*session.Session, error) {
	var matches []*session.Session
	for _, candidate := range ss.Sessions {
		owner, isString := candidate.Get(sessionKeyPrincipal).(string)
		if !isString || owner != principal || candidate.Name() != sessionName {
			continue
		}
		matches = append(matches, candidate)
	}
	return matches, nil
}
// InvalidateByPrincipalName invalidates every session found by FindByPrincipalName(principal, sessionName).
func (ss *MockedSessionStore) InvalidateByPrincipalName(principal, sessionName string) error {
	found, err := ss.FindByPrincipalName(principal, sessionName)
	if err == nil {
		err = ss.Invalidate(found...)
	}
	return err
}
func (ss *MockedSessionStore) WithContext(_ context.Context) session.Store {
return ss
}
// toKey builds the map key for a session with the given id and name, e.g. "session-<id>-<name>".
func (ss *MockedSessionStore) toKey(id, name string) string {
	return fmt.Sprintf("session-%s-%s", id, name)
}

// key returns the map key of sess, derived from its current ID and name via toKey so both always agree.
func (ss *MockedSessionStore) key(sess *session.Session) string {
	return ss.toKey(sess.GetID(), sess.Name())
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
)
// MockedTenantStore is an in-memory tenant store with two lookups: by tenant ID and by external ID.
type MockedTenantStore struct {
	idLookup    map[string]*mockedTenant
	extIdLookup map[string]*mockedTenant
}
// NewMockedTenantStore builds a MockedTenantStore from the given tenant properties.
// A tenant is indexed by ID and/or external ID only when the respective value is non-empty;
// later entries overwrite earlier ones with the same key.
func NewMockedTenantStore(props ...*MockedTenantProperties) *MockedTenantStore {
	store := &MockedTenantStore{
		idLookup:    make(map[string]*mockedTenant),
		extIdLookup: make(map[string]*mockedTenant),
	}
	for _, p := range props {
		tenant := newTenant(p)
		if tenant.ExternalId != "" {
			store.extIdLookup[tenant.ExternalId] = tenant
		}
		if tenant.ID != "" {
			store.idLookup[tenant.ID] = tenant
		}
	}
	return store
}
// newTenant converts MockedTenantProperties into a mockedTenant whose provider is MockedProviderID.
func newTenant(props *MockedTenantProperties) *mockedTenant {
	var t mockedTenant
	t.ID = props.ID
	t.ExternalId = props.ExternalId
	t.ProviderId = MockedProviderID
	t.Permissions = props.Perms
	return &t
}
// LoadTenantById returns the tenant with the given ID, or an error when it is unknown.
func (s *MockedTenantStore) LoadTenantById(_ context.Context, id string) (*security.Tenant, error) {
	t, found := s.idLookup[id]
	if !found {
		return nil, fmt.Errorf("cannot find tenant with ID [%s]", id)
	}
	return toSecurityTenant(t), nil
}

// LoadTenantByExternalId returns the tenant with the given external ID, or an error when it is unknown.
func (s *MockedTenantStore) LoadTenantByExternalId(_ context.Context, name string) (*security.Tenant, error) {
	t, found := s.extIdLookup[name]
	if !found {
		return nil, fmt.Errorf("cannot find tenant with external ID [%s]", name)
	}
	return toSecurityTenant(t), nil
}
// toSecurityTenant converts a mockedTenant into a security.Tenant; the external ID doubles as display name.
func toSecurityTenant(mocked *mockedTenant) *security.Tenant {
	tenant := security.Tenant{
		ProviderId:  mocked.ProviderId,
		DisplayName: mocked.ExternalId,
		ExternalId:  mocked.ExternalId,
		Id:          mocked.ID,
	}
	return &tenant
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"encoding/base64"
"encoding/json"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"strings"
"time"
)
// tokenDelimiter separates the fields of MockedToken.String().
const tokenDelimiter = "~"
/*************************
Token
*************************/
// MockedTokenInfo is the serializable payload of a MockedToken.
// Exp and Iss hold expiry/issue times as Unix nanoseconds; 0 means "not set" (zero time.Time).
type MockedTokenInfo struct {
	ClientID    string   `json:"ClientID"`
	UName       string   `json:"UName"`
	UID         string   `json:"UID"`
	TID         string   `json:"TID"`
	TExternalId string   `json:"TExternalId"`
	OrigU       string   `json:"OrigU"`
	Exp         int64    `json:"Exp"`
	Iss         int64    `json:"Iss"`
	Scopes      []string `json:"Scopes"`
}

// MockedToken implements oauth2.AccessToken
// Unless Token is set, its text form is the standard base64 encoding of the JSON-encoded MockedTokenInfo.
type MockedToken struct {
	MockedTokenInfo
	// Token, when non-empty, is returned verbatim by MarshalText instead of an encoded MockedTokenInfo.
	Token   string
	ExpTime time.Time `json:"-"`
	IssTime time.Time `json:"-"`
}

// MarshalText implements encoding.TextMarshaler.
// It returns Token verbatim when set; otherwise it encodes MockedTokenInfo with Exp/Iss taken from ExpTime/IssTime.
// A zero ExpTime/IssTime is encoded as 0: time.Time.UnixNano is undefined for the zero time (it overflows int64),
// which previously turned "no expiry" into an arbitrary past date after a round trip.
func (mt MockedToken) MarshalText() (text []byte, err error) {
	if len(mt.Token) != 0 {
		return []byte(mt.Token), nil
	}
	// mt is a copy (value receiver), so filling Exp/Iss does not mutate the caller's token.
	mt.Exp = mockedTokenUnixNano(mt.ExpTime)
	mt.Iss = mockedTokenUnixNano(mt.IssTime)
	if text, err = json.Marshal(mt.MockedTokenInfo); err != nil {
		return nil, err
	}
	return []byte(base64.StdEncoding.EncodeToString(text)), nil
}

// UnmarshalText implements encoding.TextUnmarshaler, reversing MarshalText for encoded (non-verbatim) tokens.
// Exp/Iss of 0 decode to the zero time.Time. Token is left untouched.
func (mt *MockedToken) UnmarshalText(text []byte) error {
	data, e := base64.StdEncoding.DecodeString(string(text))
	if e != nil {
		return e
	}
	if e := json.Unmarshal(data, &mt.MockedTokenInfo); e != nil {
		return e
	}
	mt.ExpTime = mockedTokenTime(mt.Exp)
	mt.IssTime = mockedTokenTime(mt.Iss)
	return nil
}

// mockedTokenUnixNano converts t to Unix nanoseconds, mapping the zero time to 0.
func mockedTokenUnixNano(t time.Time) int64 {
	if t.IsZero() {
		return 0
	}
	return t.UnixNano()
}

// mockedTokenTime converts Unix nanoseconds back to time.Time, mapping 0 to the zero time.
func mockedTokenTime(nanos int64) time.Time {
	if nanos == 0 {
		return time.Time{}
	}
	return time.Unix(0, nanos)
}
// String renders the token as tokenDelimiter-joined fields: username, user ID, tenant ID, tenant external ID,
// original username and expiry time (utils.ISO8601Milliseconds layout).
func (mt MockedToken) String() string {
	expiry := mt.ExpTime.Format(utils.ISO8601Milliseconds)
	fields := []string{mt.UName, mt.UID, mt.TID, mt.TExternalId, mt.OrigU, expiry}
	return strings.Join(fields, tokenDelimiter)
}
// Value returns the token's text form as produced by MarshalText, or "" when marshalling fails.
func (mt *MockedToken) Value() string {
	if text, e := mt.MarshalText(); e == nil {
		return string(text)
	}
	return ""
}
// ExpiryTime returns ExpTime.
func (mt *MockedToken) ExpiryTime() time.Time {
	return mt.ExpTime
}

// Expired reports whether the token has a non-zero expiry time that is not after the current time.
// A zero ExpTime means the token never expires.
func (mt *MockedToken) Expired() bool {
	if mt.ExpTime.IsZero() {
		return false
	}
	return !time.Now().Before(mt.ExpTime)
}

// Details always returns a new, empty map.
func (mt *MockedToken) Details() map[string]interface{} {
	return make(map[string]interface{})
}

// Type always returns oauth2.TokenTypeBearer.
func (mt *MockedToken) Type() oauth2.TokenType {
	return oauth2.TokenTypeBearer
}

// IssueTime returns IssTime.
func (mt *MockedToken) IssueTime() time.Time {
	return mt.IssTime
}

// Scopes returns a new set built from MockedTokenInfo.Scopes.
func (mt *MockedToken) Scopes() utils.StringSet {
	scopes := mt.MockedTokenInfo.Scopes
	return utils.NewStringSet(scopes...)
}

// RefreshToken always returns nil: mocked tokens carry no refresh token.
func (mt *MockedToken) RefreshToken() oauth2.RefreshToken {
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"sync"
"time"
)
/*************************
Interface
*************************/
// MockedTokenRevoker lets tests revoke mocked tokens.
type MockedTokenRevoker interface {
	// Revoke revokes the token with the given value.
	Revoke(value string)
	// RevokeAll revokes tokens issued before the call (as implemented by mockedTokenBase).
	RevokeAll()
}
/*************************
Base
*************************/
// mockedTokenBase implements MockedTokenRevoker and holds the state shared by the mocked token implementations,
// e.g. mockedTokenStoreReader and mockedAuthClient.
// The embedded RWMutex guards revoked and notBefore.
type mockedTokenBase struct {
	sync.RWMutex
	accounts  *mockedAccounts
	tenants   *mockedTenants
	revoked   utils.StringSet
	notBefore time.Time
}
// Revoke adds the token value to the revoked set.
func (b *mockedTokenBase) Revoke(value string) {
	b.Lock()
	defer b.Unlock()
	revoked := b.revoked
	revoked.Add(value)
}

// RevokeAll moves the "not before" cutoff to now (UTC), revoking every token issued earlier.
func (b *mockedTokenBase) RevokeAll() {
	b.Lock()
	defer b.Unlock()
	cutoff := time.Now().UTC()
	b.notBefore = cutoff
}
// isTokenRevoked reports whether the token was revoked individually (by value) or issued before the RevokeAll cutoff.
func (b *mockedTokenBase) isTokenRevoked(token *MockedToken, value string) bool {
	b.RLock()
	defer b.RUnlock()
	if b.revoked.Has(value) {
		return true
	}
	return token.IssTime.Before(b.notBefore)
}
// newMockedToken mints a MockedToken for acct in tenant, expiring at exp and issued now (UTC).
// origUser is recorded as the original username; ClientID and Scopes are left empty.
func (b *mockedTokenBase) newMockedToken(acct *MockedAccount, tenant *mockedTenant, exp time.Time, origUser string) *MockedToken {
	info := MockedTokenInfo{
		UName:       acct.MockedAccountDetails.Username,
		UID:         acct.UserId,
		TID:         tenant.ID,
		TExternalId: tenant.ExternalId,
		OrigU:       origUser,
	}
	issued := time.Now().UTC()
	return &MockedToken{MockedTokenInfo: info, IssTime: issued, ExpTime: exp}
}
func (b *mockedTokenBase) parseMockedToken(value string) (*MockedToken, error) {
mt := &MockedToken{}
if e := mt.UnmarshalText([]byte(value)); e != nil {
return nil, e
}
if b.isTokenRevoked(mt, value) {
return nil, fmt.Errorf("[Mocked Error]: token revoked")
}
return mt, nil
}
// newMockedAuth builds an oauth2.Authentication from a parsed token and its account.
// The user authentication is in the authenticated state with every account permission granted. The security
// details mirror the token and account. The OAuth2 request is approved, carries the token's scopes, and
// uses client ID "mock" when the token has none.
func (b *mockedTokenBase) newMockedAuth(mt *MockedToken, acct *MockedAccount) oauth2.Authentication {
	userAuth := oauth2.NewUserAuthentication(func(uo *oauth2.UserAuthOption) {
		uo.Principal = mt.UName
		uo.State = security.StateAuthenticated
		granted := map[string]interface{}{}
		for perm := range acct.MockedAccountDetails.Permissions {
			granted[perm] = true
		}
		uo.Permissions = granted
	})
	secDetails := NewMockedSecurityDetails(func(d *SecurityDetailsMock) {
		*d = SecurityDetailsMock{
			Username:         acct.Username(),
			UserId:           acct.UserId,
			TenantExternalId: mt.TExternalId,
			TenantId:         mt.TID,
			Exp:              mt.ExpTime,
			Iss:              mt.IssTime,
			Permissions:      acct.MockedAccountDetails.Permissions,
			Tenants:          acct.AssignedTenants,
			OrigUsername:     mt.OrigU,
		}
	})
	return oauth2.NewAuthentication(func(ao *oauth2.AuthOption) {
		ao.Request = oauth2.NewOAuth2Request(func(rd *oauth2.RequestDetails) {
			if len(mt.ClientID) == 0 {
				rd.ClientId = "mock"
			} else {
				rd.ClientId = mt.ClientID
			}
			rd.Approved = true
			rd.Scopes = utils.NewStringSet(mt.MockedTokenInfo.Scopes...)
		})
		ao.Token = mt
		ao.UserAuth = userAuth
		ao.Details = secDetails
	})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
)
// mockedTokenStoreReader is a mocked oauth2.TokenStoreReader that reads tokens through a mockedTokenBase.
type mockedTokenStoreReader struct {
	*mockedTokenBase
}
// NewMockedTokenStoreReader creates a mocked oauth2.TokenStoreReader from account and tenant properties.
// The returned reader also implements MockedTokenRevoker.
func NewMockedTokenStoreReader(acctsProps map[string]*MockedAccountProperties, tenantProps map[string]*MockedTenantProperties) oauth2.TokenStoreReader {
	base := &mockedTokenBase{
		accounts: newMockedAccounts(acctsProps),
		tenants:  newMockedTenants(tenantProps),
		revoked:  utils.NewStringSet(),
	}
	return newMockedTokenStoreReader(base)
}
// newMockedTokenStoreReader wraps an existing mockedTokenBase, sharing its accounts, tenants and revocation state.
func newMockedTokenStoreReader(base *mockedTokenBase) oauth2.TokenStoreReader {
	reader := &mockedTokenStoreReader{mockedTokenBase: base}
	return reader
}
// ReadAuthentication builds an oauth2.Authentication from an access token value.
// Only oauth2.TokenHintAccessToken is accepted. Undecodable or revoked tokens, or tokens of unknown users,
// yield an "invalid access token" error.
func (r *mockedTokenStoreReader) ReadAuthentication(_ context.Context, tokenValue string, hint oauth2.TokenHint) (oauth2.Authentication, error) {
	if hint != oauth2.TokenHintAccessToken {
		return nil, fmt.Errorf("[Mocked Error] wrong token hint")
	}
	token, err := r.parseMockedToken(tokenValue)
	if err != nil {
		return nil, fmt.Errorf("[Mocked Error] invalid access token")
	}
	account, found := r.accounts.lookup[token.UName]
	if !found {
		return nil, fmt.Errorf("[Mocked Error] invalid access token")
	}
	return r.newMockedAuth(token, account), nil
}
// ReadAccessToken decodes an access token value. Undecodable or revoked tokens, or tokens of unknown users,
// yield an "invalid access token" error.
func (r *mockedTokenStoreReader) ReadAccessToken(_ context.Context, value string) (oauth2.AccessToken, error) {
	token, err := r.parseMockedToken(value)
	if err != nil {
		return nil, fmt.Errorf("[Mocked Error] invalid access token")
	}
	if _, known := r.accounts.lookup[token.UName]; !known {
		return nil, fmt.Errorf("[Mocked Error] invalid access token")
	}
	return token, nil
}

// ReadRefreshToken is not supported by the mocked store and always returns an error.
func (r *mockedTokenStoreReader) ReadRefreshToken(_ context.Context, _ string) (oauth2.RefreshToken, error) {
	err := fmt.Errorf("ReadRefreshToken is not implemented for mocked token store")
	return nil, err
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"embed"
appconfig "github.com/cisco-open/go-lanai/pkg/appconfig/init"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
securityint "github.com/cisco-open/go-lanai/pkg/integrate/security"
"github.com/cisco-open/go-lanai/pkg/integrate/security/scope"
"github.com/cisco-open/go-lanai/pkg/integrate/security/seclient"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"go.uber.org/fx"
"time"
)
//var logger = log.New("SEC.Test")

// defaultMockingConfigFS embeds test-scopes.yml. WithMockedScopes always registers it as ad-hoc bootstrap config.
//go:embed test-scopes.yml
var defaultMockingConfigFS embed.FS
/**************************
Options
**************************/
// WithMockedScopes is a test.Options that initializes github.com/cisco-open/go-lanai/pkg/integrate/security/scope
// with mocked security scopes. Each given embed.FS is loaded as ad-hoc application config. The embedded
// default test-scopes.yml is always loaded as ad-hoc bootstrap config, so it applies when no config is provided.
func WithMockedScopes(mocksConfigFS ...embed.FS) test.Options {
	fxOpts := make([]fx.Option, 0, len(mocksConfigFS)+3)
	for _, cfgFS := range mocksConfigFS {
		fxOpts = append(fxOpts, appconfig.FxEmbeddedApplicationAdHoc(cfgFS))
	}
	fxOpts = append(fxOpts,
		appconfig.FxEmbeddedBootstrapAdHoc(defaultMockingConfigFS),
		fx.Provide(securityint.BindSecurityIntegrationProperties),
		fx.Provide(ProvideScopeMocks),
	)
	moduleOpt := apptest.WithModules(scope.Module)
	fxOpt := apptest.WithFxOptions(fxOpts...)
	return func(opt *test.T) {
		moduleOpt(opt)
		fxOpt(opt)
	}
}
/**************************
fx options
**************************/
// MocksDIOut is the fx.Out result of ProvideScopeMocks.
type MocksDIOut struct {
	fx.Out
	AuthClient   seclient.AuthenticationClient
	TokenReader  oauth2.TokenStoreReader
	TokenRevoker MockedTokenRevoker
}
// ProvideScopeMocks is for internal usage. Exported for cross-package reference.
// Try use WithMockedScopes instead.
//
// It binds the scope mocking properties and builds a single mockedTokenBase (accounts, tenants, revocation
// state) that the auth client, token reader and revoker all share.
func ProvideScopeMocks(ctx *bootstrap.ApplicationContext) MocksDIOut {
	props := bindScopeMockingProperties(ctx)
	base := &mockedTokenBase{
		accounts: newMockedAccounts(props.Accounts.MapValues()),
		tenants:  newMockedTenants(props.Tenants.MapValues()),
		revoked:  utils.NewStringSet(),
	}
	validity := time.Duration(props.TokenValidity)
	return MocksDIOut{
		AuthClient:   newMockedAuthClient(base, validity),
		TokenReader:  newMockedTokenStoreReader(base),
		TokenRevoker: base,
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package sectest
import (
"context"
"github.com/cisco-open/go-lanai/pkg/security"
"github.com/cisco-open/go-lanai/pkg/security/oauth2"
"github.com/cisco-open/go-lanai/pkg/utils"
"github.com/cisco-open/go-lanai/pkg/web/template"
)
// init registers template.ContextModelValuer(security.Get) as the global template model valuer
// for template.ModelKeySecurity.
func init() {
	securityValuer := template.ContextModelValuer(security.Get)
	template.RegisterGlobalModelValuer(template.ModelKeySecurity, securityValuer)
}
/**************************
Function
**************************/
// SecurityContextOptions configures a SecurityContextOption; see ContextWithSecurity.
type SecurityContextOptions func(opt *SecurityContextOption)

// SecurityContextOption holds the settings that ContextWithSecurity applies.
type SecurityContextOption struct {
	// Authentication override any other fields
	Authentication security.Authentication
}
// ContextWithSecurity set given SecurityContextOptions in the given context, returning the new context
func ContextWithSecurity(ctx context.Context, opts ...SecurityContextOptions) context.Context {
opt := SecurityContextOption{}
for _, fn := range opts {
fn(&opt)
}
// We force a new utils.MutableContext is created,
// in order to preserve any security context in the original "ctx"
testCtx := utils.NewMutableContext(ctx)
security.MustSet(testCtx, opt.Authentication)
return testCtx
}
// WithMockedSecurity mocks an oauth2.Authentication in the given context and returns the new context.
//
// Deprecated: use ContextWithSecurity(ctx, MockedAuthentication(opts...)) instead
func WithMockedSecurity(ctx context.Context, opts ...SecurityMockOptions) context.Context {
	mocked := MockedAuthentication(opts...)
	return ContextWithSecurity(ctx, mocked)
}
/**************************
Options
**************************/
// Authentication returns a SecurityContextOptions that sets the authentication to auth as-is.
func Authentication(auth security.Authentication) SecurityContextOptions {
	setAuth := func(opt *SecurityContextOption) {
		opt.Authentication = auth
	}
	return setAuth
}
// MockedAuthentication provides a SecurityContextOptions that sets the authentication to a mocked oauth2.Authentication
//
// The authentication is built from NewMockedSecurityDetails(opts...) each time the option is applied:
//   - user authentication: principal details.Username(), authenticated state, every details permission granted,
//     and details.KVs as user details;
//   - token: a MockedToken with username, user ID, tenant ID/external ID and original username from details,
//     details.AccessToken as its verbatim value, and details.Exp / details.Iss as expiry and issue times;
//   - request: approved, with client ID, scopes, grant type, response types, parameters and extensions from details.
//
// NOTE(review): the token's MockedTokenInfo does not receive details.ClientID or details.Scopes, so token.Scopes()
// is empty even though the request carries details.Scopes — confirm this is intended.
func MockedAuthentication(opts ...SecurityMockOptions) SecurityContextOptions {
	return func(opt *SecurityContextOption) {
		details := NewMockedSecurityDetails(opts...)
		user := oauth2.NewUserAuthentication(func(opt *oauth2.UserAuthOption) {
			opt.Principal = details.Username()
			opt.State = security.StateAuthenticated
			opt.Permissions = map[string]interface{}{}
			for perm := range details.Permissions() {
				opt.Permissions[perm] = true
			}
			opt.Details = details.KVs
		})
		// When details.AccessToken is empty, the token's text form is the encoded MockedTokenInfo (see MarshalText).
		token := &MockedToken{
			MockedTokenInfo: MockedTokenInfo{
				UName:       details.Username(),
				UID:         details.UserId(),
				TID:         details.TenantId(),
				TExternalId: details.TenantExternalId(),
				OrigU:       details.OrigUsername,
			},
			Token:   details.AccessToken,
			ExpTime: details.Exp,
			IssTime: details.Iss,
		}
		auth := oauth2.NewAuthentication(func(opt *oauth2.AuthOption) {
			opt.Request = oauth2.NewOAuth2Request(func(opt *oauth2.RequestDetails) {
				opt.ClientId = details.ClientID
				opt.Scopes = details.Scopes
				opt.Approved = true
				opt.GrantType = details.OAuth2GrantType
				opt.ResponseTypes = utils.NewStringSetFrom(details.OAuth2ResponseTypes)
				// OAuth2Parameters go into both Parameters and Extensions. OAuth2Extensions are applied afterwards,
				// so they win on key collisions within Extensions.
				// NOTE(review): assumes NewOAuth2Request pre-initializes the Parameters and Extensions maps — confirm.
				for k, v := range details.OAuth2Parameters {
					opt.Parameters[k] = v
					opt.Extensions[k] = v
				}
				for k, v := range details.OAuth2Extensions {
					opt.Extensions[k] = v
				}
			})
			opt.Token = token
			opt.UserAuth = user
			opt.Details = details
		})
		opt.Authentication = auth
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package test
import (
"container/list"
"context"
"fmt"
"github.com/onsi/gomega"
. "github.com/onsi/gomega"
"math/rand"
"path"
"reflect"
"runtime"
"strings"
"testing"
)
/****************************
Sub Tests
****************************/
// SubTestFunc is the signature of a sub-test that receives the test context.
// Register one with SubTest or AnonymousSubTest.
type SubTestFunc func(ctx context.Context, t *testing.T)

// GomegaSubTestFunc is the signature of a sub-test that also receives a gomega.WithT bound to t.
// Register one with GomegaSubTest.
type GomegaSubTestFunc func(ctx context.Context, t *testing.T, g *gomega.WithT)
// SubTestFuncWithGomega adapts a GomegaSubTestFunc to a SubTestFunc by creating a gomega.WithT for each run.
func SubTestFuncWithGomega(st GomegaSubTestFunc) SubTestFunc {
	return func(ctx context.Context, t *testing.T) {
		g := NewWithT(t)
		st(ctx, t, g)
	}
}
// FuncName returns a name that could be used as a sub test name, derived from fn's runtime name.
// The package path, receiver, closure suffixes (".funcN[.N...]") and the method-value suffix "-fm" are stripped,
// leaving the innermost named function, e.g. "github.com/x/pkg.TestSomething.func1" -> "TestSomething".
// When suffixed is true, "@" plus 4 random alphanumeric characters is appended.
// The function panics if fn is not a func.
func FuncName(fn interface{}, suffixed bool) string {
	fnName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
	_, fnName = path.Split(fnName)
	// fnName is in format of "<package>[.receiver]<.NamedFunction>[.funcN[<.N>...]]".
	// Cut at the first whole closure segment. Matching any ".func" substring instead would also truncate
	// named functions such as "pkg.functionalTest".
	segments := strings.Split(fnName, ".")
	for i := 1; i < len(segments); i++ {
		if isClosureNameSegment(segments[i]) {
			segments = segments[:i]
			break
		}
	}
	fnName = strings.TrimSuffix(segments[len(segments)-1], "-fm")
	if suffixed {
		fnName = fnName + "@" + randString(4)
	}
	return fnName
}

// isClosureNameSegment reports whether seg is a compiler-generated closure segment "funcN" (N is one or more digits).
func isClosureNameSegment(seg string) bool {
	digits := strings.TrimPrefix(seg, "func")
	if len(digits) == 0 || len(digits) == len(seg) {
		return false
	}
	for _, r := range digits {
		if r < '0' || r > '9' {
			return false
		}
	}
	return true
}
/****************************
Test Options
****************************/
// SubTest is an Options that registers subtest under the given name.
// Registering an existing name replaces its function but keeps its original position.
func SubTest(subtest SubTestFunc, name string) Options {
	register := func(opt *T) {
		opt.SubTests.Set(name, subtest)
	}
	return register
}

// AnonymousSubTest is an Options that registers st under the generated name FuncName(st, true).
func AnonymousSubTest(st SubTestFunc) Options {
	generated := FuncName(st, true)
	return SubTest(st, generated)
}
// GomegaSubTest is an Options that registers a GomegaSubTestFunc under the given name.
// Without a name, FuncName(st, true) is used. With several name arguments, the first is a format and the rest
// are its arguments: fmt.Sprintf(name[0], name[1:]...)
func GomegaSubTest(st GomegaSubTestFunc, name ...string) Options {
	var n string
	switch len(name) {
	case 0:
		n = FuncName(st, true)
	default:
		format, rest := name[0], name[1:]
		args := make([]interface{}, 0, len(rest))
		for _, a := range rest {
			args = append(args, a)
		}
		n = fmt.Sprintf(format, args...)
	}
	return SubTest(SubTestFuncWithGomega(st), n)
}
// SubTestSetup is an Options that registers fn to run before each sub test.
func SubTestSetup(fn SetupFunc) Options {
	return func(opt *T) {
		hook := &orderedHook{setupFunc: fn}
		opt.SubTestHooks = append(opt.SubTestHooks, hook)
	}
}

// SubTestTeardown is an Options that registers fn to run after each sub test.
func SubTestTeardown(fn TeardownFunc) Options {
	return func(opt *T) {
		hook := &orderedHook{teardownFunc: fn}
		opt.SubTestHooks = append(opt.SubTestHooks, hook)
	}
}
/****************************
SubTest List
****************************/
// subTestEntry is one named sub-test stored in a SubTestOrderedMap's list.
type subTestEntry struct {
	name string
	fn   SubTestFunc
}
// SubTestOrderedMap is a name -> SubTestFunc map that remembers insertion order.
// Adopted, with reduced functionality, from https://github.com/elliotchance/orderedmap/blob/master/orderedmap.go
// kv indexes the list elements of ll by name; ll holds *subTestEntry values in insertion order.
type SubTestOrderedMap struct {
	kv map[string]*list.Element
	ll *list.List
}
// NewSubTestOrderedMap returns an empty, ready-to-use SubTestOrderedMap.
func NewSubTestOrderedMap() *SubTestOrderedMap {
	m := new(SubTestOrderedMap)
	m.kv = map[string]*list.Element{}
	m.ll = list.New()
	return m
}
// Len returns the number of registered sub tests.
func (m *SubTestOrderedMap) Len() int {
	count := len(m.kv)
	return count
}

// Get returns the function registered under key and whether it exists.
func (m *SubTestOrderedMap) Get(key string) (SubTestFunc, bool) {
	elem, ok := m.kv[key]
	if !ok {
		return nil, false
	}
	return elem.Value.(*subTestEntry).fn, true
}
// Set registers fn under name. A new name is appended at the end of the order; an existing name keeps its
// position and has its function replaced. Returns true when name was newly added.
func (m *SubTestOrderedMap) Set(name string, fn SubTestFunc) bool {
	if elem, exists := m.kv[name]; exists {
		elem.Value.(*subTestEntry).fn = fn
		return false
	}
	m.kv[name] = m.ll.PushBack(&subTestEntry{name: name, fn: fn})
	return true
}
// Keys returns the registered names in insertion order.
func (m *SubTestOrderedMap) Keys() (keys []string) {
	keys = make([]string, 0, m.Len())
	for e := m.ll.Front(); e != nil; e = e.Next() {
		keys = append(keys, e.Value.(*subTestEntry).name)
	}
	return keys
}
// Delete will remove a name from the map. It will return true if the name was
// removed (the name did exist).
func (m *SubTestOrderedMap) Delete(key string) (didDelete bool) {
element, ok := m.kv[key]
if ok {
m.ll.Remove(element)
delete(m.kv, key)
}
return ok
}
/****************************
Helpers
****************************/
// randString returns length characters drawn uniformly from [0-9a-zA-Z] using math/rand (not cryptographically secure).
func randString(length int) string {
	const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
	out := make([]byte, 0, length)
	for len(out) < length {
		//nolint:gosec // We can't use utils package here.
		out = append(out, alphabet[rand.Intn(len(alphabet))])
	}
	return string(out)
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package suitetest
import (
"fmt"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"github.com/cisco-open/go-lanai/test"
"os"
"sort"
"testing"
)
const (
	// HookOrderPackage is the hook order used by TestSetup and TestTeardown for package-level per-test hooks.
	HookOrderPackage int = -0xffff
)
// PackageHook is a package-level hook. RunTests calls Setup once before running the package's tests and
// Teardown once afterwards.
type PackageHook interface {
	Setup() error
	Teardown() error
}
// PackageOptions configures the package-level test settings collected by RunTests.
type PackageOptions func(opt *pkg)

// pkg collects package-level hooks and the per-test options shared by every test in the package.
type pkg struct {
	PackageHooks []PackageHook
	TestOptions  []test.Options
}
func RunTests(m *testing.M, opts ...PackageOptions) {
s := pkg{
PackageHooks: []PackageHook{},
TestOptions: []test.Options{},
}
for _, fn := range opts {
fn(&s)
}
sort.SliceStable(s.PackageHooks, func(i, j int) bool {
return order.OrderedFirstCompare(s.PackageHooks[i], s.PackageHooks[j])
})
// run setup TestHooks
for _, h := range s.PackageHooks {
if e := h.Setup(); e != nil {
panic(fmt.Errorf("error when setup test pkg: %v", e))
}
}
// register DefaultTestHook
test.InternalOptions = s.TestOptions
// run tests
code := m.Run()
// run teardown TestHooks in reversed order
for i := len(s.PackageHooks) - 1; i >= 0; i-- {
if e := s.PackageHooks[i].Teardown(); e != nil {
panic(fmt.Errorf("error when teardown test pkg: %v", e))
}
}
os.Exit(code)
}
/****************************
Suite Options
****************************/
type (
	// SetupFunc is a package-level setup function that runs once per package.
	SetupFunc func() error

	// TeardownFunc is a package-level teardown function that runs once per package.
	TeardownFunc func() error
)
// orderedSuiteHook implements PackageHook and order.Ordered.
// Either function may be nil, in which case the corresponding phase is a no-op.
type orderedSuiteHook struct {
	order        int
	setupFunc    SetupFunc
	teardownFunc TeardownFunc
}
// Order implements order.Ordered.
func (h *orderedSuiteHook) Order() int {
	ord := h.order
	return ord
}

// Setup runs setupFunc, or does nothing when it is nil.
func (h *orderedSuiteHook) Setup() error {
	if fn := h.setupFunc; fn != nil {
		return fn()
	}
	return nil
}

// Teardown runs teardownFunc, or does nothing when it is nil.
func (h *orderedSuiteHook) Teardown() error {
	if fn := h.teardownFunc; fn != nil {
		return fn()
	}
	return nil
}
// WithOptions groups multiple PackageOptions into one. It is typically used by other test utilities to offer a
// single entry point for a feature; test implementers are not expected to use it directly.
func WithOptions(opts ...PackageOptions) PackageOptions {
	return func(opt *pkg) {
		for i := range opts {
			opts[i](opt)
		}
	}
}
// Setup registers fn as a package setup at order 0; a higher priority (smaller order) runs first.
// Package setup runs once per test package and should only be registered in TestMain(m *testing.M).
func Setup(fn SetupFunc) PackageOptions {
	const defaultOrder = 0
	return SetupWithOrder(defaultOrder, fn)
}

// SetupWithOrder registers fn as a package setup at the given order; a higher priority (smaller order) runs first.
// Package setup runs once per test package and should only be registered in TestMain(m *testing.M).
func SetupWithOrder(order int, fn SetupFunc) PackageOptions {
	return func(opt *pkg) {
		hook := &orderedSuiteHook{setupFunc: fn, order: order}
		opt.PackageHooks = append(opt.PackageHooks, hook)
	}
}
// Teardown registers fn as a package teardown at order 0; a higher priority (smaller order) runs LAST.
// Package teardown runs once per test package and should only be registered in TestMain(m *testing.M).
func Teardown(fn TeardownFunc) PackageOptions {
	const defaultOrder = 0
	return TeardownWithOrder(defaultOrder, fn)
}

// TeardownWithOrder registers fn as a package teardown at the given order; a higher priority (smaller order) runs LAST.
// Package teardown runs once per test package and should only be registered in TestMain(m *testing.M).
func TeardownWithOrder(order int, fn TeardownFunc) PackageOptions {
	return func(opt *pkg) {
		hook := &orderedSuiteHook{teardownFunc: fn, order: order}
		opt.PackageHooks = append(opt.PackageHooks, hook)
	}
}
// TestOptions registers per-test options at package level; declare it once in TestMain(m *testing.M).
// Every registered test.Options is applied once per Test*().
func TestOptions(opts ...test.Options) PackageOptions {
	return func(opt *pkg) {
		for _, o := range opts {
			opt.TestOptions = append(opt.TestOptions, o)
		}
	}
}
// TestSetup registers fn as a setup hook for every test in the package, at order HookOrderPackage.
// It is equivalent to TestSetupWithOrder(HookOrderPackage, fn).
func TestSetup(fn test.SetupFunc) PackageOptions {
	ord := HookOrderPackage
	return TestSetupWithOrder(ord, fn)
}

// TestSetupWithOrder is a convenient function equivalent to TestOptions(test.Hooks(test.NewSetupHook(order, fn)))
func TestSetupWithOrder(order int, fn test.SetupFunc) PackageOptions {
	hook := test.NewSetupHook(order, fn)
	return TestOptions(test.Hooks(hook))
}
// TestTeardown registers fn as a teardown hook for every test in the package, at order HookOrderPackage.
// It is equivalent to TestTeardownWithOrder(HookOrderPackage, fn).
func TestTeardown(fn test.TeardownFunc) PackageOptions {
	ord := HookOrderPackage
	return TestTeardownWithOrder(ord, fn)
}

// TestTeardownWithOrder is a convenient function equivalent to TestOptions(test.Hooks(test.NewTeardownHook(order, fn)))
func TestTeardownWithOrder(order int, fn test.TeardownFunc) PackageOptions {
	hook := test.NewTeardownHook(order, fn)
	return TestOptions(test.Hooks(hook))
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package test
import (
"context"
"github.com/cisco-open/go-lanai/pkg/utils/order"
"sort"
"testing"
)
var (
// InternalOptions is an internal variable, exported for cross-package access.
// It holds the common setup/teardown options applied to all tests in the same package;
// RunTest applies them before the per-test Options. The testsuite package has options to populate this list.
// Note: when executing all tests, Go runs tests on a per-package basis.
InternalOptions = make([]Options, 0)
)
// InternalRunner is an internal type, exported for cross-package reference.
// RunTest invokes it to execute the test once all Options have been applied.
type InternalRunner func(context.Context, *T)
// SetupFunc is a setup hook function. The returned context replaces the incoming one for subsequent
// hooks and for the test itself; a non-nil error fails the test.
type SetupFunc func(ctx context.Context, t *testing.T) (context.Context, error)
// TeardownFunc is a teardown hook function. A non-nil error fails the test.
type TeardownFunc func(ctx context.Context, t *testing.T) error
// Hook is registered for tests and sub tests, and should provide a SetupFunc, a TeardownFunc, or both.
// This interface is mostly for internal usage.
// Test implementers typically use Options (e.g. Setup, Teardown, Hooks) to create instances of this interface.
type Hook interface {
// Setup runs before the test (or sub test); the returned context is passed on to later hooks and the test.
Setup(ctx context.Context, t *testing.T) (context.Context, error)
// Teardown runs after the test (or sub test) with the context produced by the setup phase.
Teardown(ctx context.Context, t *testing.T) error
}
// Options are test config functions to pass into RunTest
type Options func(opt *T)
// T embeds *testing.T and holds additional information of the test config
type T struct {
*testing.T
// runner executes the test; RunTest defaults it to unitTestRunner (see WithInternalRunner)
runner InternalRunner
// TestHooks run once around all sub tests; RunTest sorts them with order.OrderedFirstCompare
TestHooks []Hook
// SubTestHooks run around every single sub test
SubTestHooks []Hook
// SubTests are the sub tests to run, executed in the order returned by Keys()
SubTests *SubTestOrderedMap
}
// RunTest is the entry point of any Test...() function.
// It builds a T by applying the package-level InternalOptions followed by the given opts,
// stably sorts the test-level hooks by order, and hands control to the configured runner
// (unitTestRunner unless overridden via WithInternalRunner).
func RunTest(ctx context.Context, t *testing.T, opts ...Options) {
	cfg := &T{
		T:            t,
		runner:       unitTestRunner,
		TestHooks:    []Hook{},
		SubTestHooks: []Hook{},
		SubTests:     NewSubTestOrderedMap(),
	}
	// package-wide options are applied before the per-test ones
	for _, group := range [][]Options{InternalOptions, opts} {
		for _, apply := range group {
			apply(cfg)
		}
	}
	hooks := cfg.TestHooks
	sort.SliceStable(hooks, func(i, j int) bool {
		return order.OrderedFirstCompare(hooks[i], hooks[j])
	})
	cfg.runner(ctx, cfg)
}
// unitTestRunner is the default InternalRunner: it runs the test-level setup hooks, then every sub test,
// and finally the test-level teardown hooks (deferred) with the context produced by the setup hooks.
func unitTestRunner(ctx context.Context, t *T) {
	setupCtx := runTestSetupHooks(ctx, t.T, t.TestHooks, "error when setup test")
	defer runTestTeardownHooks(setupCtx, t.T, t.TestHooks, "error when cleanup test")
	InternalRunSubTests(setupCtx, t)
}
// InternalRunSubTests is an internal function, exported for cross-package reference.
// It runs every registered sub test, in the order returned by t.SubTests.Keys(), as a Go sub test (t.Run).
// Each sub test gets its own context, derived from ctx by running the SubTestHooks setup functions.
// The SubTestHooks teardown functions run with that same context after the sub test function returns.
func InternalRunSubTests(ctx context.Context, t *T) {
	for _, name := range t.SubTests.Keys() {
		fn, ok := t.SubTests.Get(name)
		if !ok {
			continue
		}
		t.Run(name, func(goT *testing.T) {
			// Use a per-sub-test variable. Assigning to the captured ctx would make every sub test
			// inherit the setup context of the previous one, and would race when sub tests run in parallel.
			subCtx := runTestSetupHooks(ctx, goT, t.SubTestHooks, "error when setup sub test")
			defer runTestTeardownHooks(subCtx, goT, t.SubTestHooks, "error when cleanup sub test")
			fn(subCtx, goT)
		})
	}
}
// runTestSetupHooks runs each hook's Setup in slice order, chaining the returned context into the next hook.
// The first error fails the test via t.Fatalf, prefixed with errMsg. It returns the final context.
func runTestSetupHooks(ctx context.Context, t *testing.T, hooks []Hook, errMsg string) context.Context {
	current := ctx
	for i := range hooks {
		next, err := hooks[i].Setup(current, t)
		current = next
		if err != nil {
			t.Fatalf("%s: %v", errMsg, err)
		}
	}
	return current
}
// runTestTeardownHooks runs each hook's Teardown in slice order with the same ctx.
// The first error fails the test via t.Fatalf (prefixed with errMsg), which stops the remaining teardowns.
func runTestTeardownHooks(ctx context.Context, t *testing.T, hooks []Hook, errMsg string) {
	for _, hook := range hooks {
		err := hook.Teardown(ctx, t)
		if err == nil {
			continue
		}
		t.Fatalf("%s: %v", errMsg, err)
	}
}
//func setupCleanup(t *testing.T, hooks []Hook, errMsg string) {
// // register cleanup
// for _, h := range hooks {
// fn := h.Teardown
// t.Cleanup(func() {
// if e := fn(t); e != nil {
// t.Fatalf("%s: %v", errMsg, e)
// }
// })
// }
//}
// orderedHook implements Hook and order.Ordered.
// Either function may be nil, in which case the corresponding phase is a no-op.
type orderedHook struct {
// order is returned by Order(); RunTest sorts test-level hooks by it via order.OrderedFirstCompare
order int
setupFunc SetupFunc
teardownFunc TeardownFunc
}
// NewHook creates an ordered Hook with the given order and functions.
// Either setupFunc or teardownFunc may be nil, which makes that phase a no-op.
func NewHook(order int, setupFunc SetupFunc, teardownFunc TeardownFunc) *orderedHook {
	h := new(orderedHook)
	h.order = order
	h.setupFunc = setupFunc
	h.teardownFunc = teardownFunc
	return h
}
// NewSetupHook creates an ordered Hook that only has a setup function.
func NewSetupHook(order int, setupFunc SetupFunc) *orderedHook {
	return &orderedHook{order: order, setupFunc: setupFunc}
}
// NewTeardownHook creates an ordered Hook that only has a teardown function.
func NewTeardownHook(order int, teardownFunc TeardownFunc) *orderedHook {
	return &orderedHook{order: order, teardownFunc: teardownFunc}
}
// Order implements order.Ordered and returns the order given at construction.
func (h *orderedHook) Order() int {
	o := h.order
	return o
}
// Setup implements Hook. It delegates to the setup function, or returns ctx unchanged when there is none.
func (h *orderedHook) Setup(ctx context.Context, t *testing.T) (context.Context, error) {
	if fn := h.setupFunc; fn != nil {
		return fn(ctx, t)
	}
	return ctx, nil
}
// Teardown implements Hook. It delegates to the teardown function, or does nothing when there is none.
func (h *orderedHook) Teardown(ctx context.Context, t *testing.T) error {
	if fn := h.teardownFunc; fn != nil {
		return fn(ctx, t)
	}
	return nil
}
/****************************
Common Test Options
****************************/
// WithInternalRunner is an internal option, exported for cross-package access.
// It replaces the runner RunTest uses to execute the test (unitTestRunner by default).
func WithInternalRunner(runner InternalRunner) Options {
	return func(t *T) {
		t.runner = runner
	}
}
// WithOptions groups multiple Options into one, applying them in the given order.
// This is mostly used by other testing utilities to provide grouped test configs.
func WithOptions(opts ...Options) Options {
	return func(t *T) {
		for i := range opts {
			opts[i](t)
		}
	}
}
// Setup is an Options that registers fn, at order 0, as a test-level hook that runs before ANY sub test starts.
func Setup(fn SetupFunc) Options {
	return func(t *T) {
		hook := NewSetupHook(0, fn)
		t.TestHooks = append(t.TestHooks, hook)
	}
}
// Teardown is an Options that registers fn, at order 0, as a test-level hook that runs after ALL sub tests finish.
func Teardown(fn TeardownFunc) Options {
	return func(t *T) {
		hook := NewTeardownHook(0, fn)
		t.TestHooks = append(t.TestHooks, hook)
	}
}
// Hooks is an Options that registers multiple Hook instances as test-level hooks.
// Test implementers are recommended to use Setup or Teardown instead.
func Hooks(hooks ...Hook) Options {
	return func(t *T) {
		t.TestHooks = append(t.TestHooks, hooks...)
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
// Package golden contains utility functions for golden file testing
//
// Golden File Testing pattern explained here:
// https://ieftimov.com/posts/testing-in-go-golden-files/
//
// # PopulateGoldenFiles will need to be added to the first test run and then removed
//
// Golden files are populated and asserted based on the current test's name.
// t is typically a *testing.T (see https://pkg.go.dev/testing#T).
// TODO this package has many limitations, e.g. it only works with JSON and structs, and it's not currently used by anyone.
// Consider removing or improving it.
package golden
import (
"encoding/json"
"github.com/google/go-cmp/cmp"
"github.com/iancoleman/strcase"
"github.com/sergi/go-diff/diffmatchpatch"
"os"
"path/filepath"
"reflect"
"strings"
)
const (
// MarshalPrefix is the per-line prefix passed to json.MarshalIndent when writing golden files
MarshalPrefix = ""
// MarshalIndent is the indentation passed to json.MarshalIndent when writing golden files
MarshalIndent = " "
)
// GoldenFileTestingT is the subset of *testing.T used by this package.
type GoldenFileTestingT interface {
Fatalf(format string, args ...any)
Errorf(format string, args ...any)
Name() string
}
// PopulateGoldenFiles writes data, marshalled as indented JSON (MarshalPrefix/MarshalIndent),
// to the path returned by GetGoldenFilePath. data must be a struct value (not a pointer, []byte or string).
//
// It always reports a test error via t.Errorf, so the test fails and the call cannot be left in by accident.
// Because Errorf does not stop the test, the file is still written.
// Add the call for the first run to create the golden file, then remove it.
// It refuses to overwrite an existing golden file.
// TODO review this function: if the function fails the test at beginning, what's the point to have it?
func PopulateGoldenFiles(t GoldenFileTestingT, data interface{}) {
	t.Errorf("Running PopulateGoldenFiles will result in a failed test.")
	if kind := reflect.ValueOf(data).Kind(); kind != reflect.Struct {
		t.Fatalf("expected data to be of type struct and not of type: %v", kind)
	}
	goldenFilePath := GetGoldenFilePath(t)
	b, err := json.MarshalIndent(data, MarshalPrefix, MarshalIndent)
	if err != nil {
		t.Fatalf("unable to marshal to json: %v", err)
	}
	if _, err := os.Stat(goldenFilePath); err == nil {
		t.Fatalf("cannot use this function to update golden files")
	}
	if err := os.MkdirAll(filepath.Dir(goldenFilePath), 0755); err != nil {
		// include the cause; previously it was dropped
		t.Fatalf("unable to mkdir to golden file path: %v", err)
	}
	if err := os.WriteFile(goldenFilePath, b, 0600); err != nil {
		t.Fatalf("unable to write golden file: %v", err)
	}
}
// GetGoldenFilePath maps the current test name to a golden file path under testdata/golden.
// Every "/"-separated segment of t.Name() except the last becomes a directory.
// The last segment, converted to snake_case, becomes the ".json" file name.
// e.g. "TestX/SubTest/TableCase" -> testdata/golden/TestX/SubTest/table_case.json
func GetGoldenFilePath(t GoldenFileTestingT) string {
	segments := strings.Split(t.Name(), "/")
	last := len(segments) - 1
	parts := append([]string{"testdata", "golden"}, segments[:last]...)
	parts = append(parts, strcase.ToSnake(segments[last])+".json")
	return filepath.Join(parts...)
}
// Assert asserts that data, marshalled as indented JSON, matches the golden file at GetGoldenFilePath(t).
// data must be a struct value (not a pointer, []byte or string).
// JSON is produced with MarshalPrefix/MarshalIndent, the same settings PopulateGoldenFiles writes with,
// so a freshly populated golden file compares byte-equal.
// On mismatch, a colored diff is reported via t.Errorf: [red] is in the golden file but missing from data,
// [green] is extra in data.
func Assert(t GoldenFileTestingT, data interface{}) {
	if reflect.ValueOf(data).Kind() != reflect.Struct {
		t.Fatalf("expected data to be of type struct")
	}
	goldenData, err := os.ReadFile(GetGoldenFilePath(t))
	if err != nil {
		t.Fatalf("unable to read golden file: %v", err)
	}
	dataJSON, err := json.MarshalIndent(data, MarshalPrefix, MarshalIndent)
	if err != nil {
		t.Fatalf("unable to marshal to json: %v", err)
	}
	if !cmp.Equal(goldenData, dataJSON) {
		dmp := diffmatchpatch.New()
		diffs := dmp.DiffMain(string(goldenData), string(dataJSON), false)
		t.Errorf("[red] missing, [green] extra:\n%v", dmp.DiffPrettyText(diffs))
	}
}
package gomegautils
import (
"errors"
"fmt"
errorutils "github.com/cisco-open/go-lanai/pkg/utils/error"
"github.com/onsi/gomega/format"
"github.com/onsi/gomega/types"
)
// IsError returns a types.GomegaMatcher that matches when errors.Is(actual, expected) holds.
// If expected is (or wraps) an errorutils.ErrorCoder, its code is reported in the failure message.
// If it is (or wraps) an errorutils.ComparableErrorCoder, its code mask is reported as well.
func IsError(expected error) types.GomegaMatcher {
	matcher := &GomegaErrorMatcher{error: expected}
	var coder errorutils.ErrorCoder
	if errors.As(expected, &coder) {
		matcher.code = coder.Code()
	}
	var maskCoder errorutils.ComparableErrorCoder
	if errors.As(expected, &maskCoder) {
		matcher.mask = maskCoder.CodeMask()
	}
	return matcher
}
// HaveErrorTypeCode returns a types.GomegaMatcher that matches errorutils.CodedError with the given top-level type code.
// The expected code and mask (errorutils.ErrorTypeMask) are reported in the failure message.
func HaveErrorTypeCode(typeCode int64) types.GomegaMatcher {
	expected := errorutils.NewErrorType(typeCode, "error type")
	return &GomegaErrorMatcher{error: expected, code: typeCode, mask: errorutils.ErrorTypeMask}
}
// HaveErrorSubTypeCode returns a types.GomegaMatcher that matches errorutils.CodedError with the given sub-type code
// (passed as typeCode). The expected code and mask (errorutils.ErrorSubTypeMask) are reported in the failure message.
func HaveErrorSubTypeCode(typeCode int64) types.GomegaMatcher {
	expected := errorutils.NewErrorSubType(typeCode, "error sub-type")
	return &GomegaErrorMatcher{error: expected, code: typeCode, mask: errorutils.ErrorSubTypeMask}
}
// HaveErrorCode returns a types.GomegaMatcher that matches errorutils.CodedError with the given full code
// (passed as typeCode). The expected code and mask (errorutils.DefaultErrorCodeMask) are reported in the failure message.
func HaveErrorCode(typeCode int64) types.GomegaMatcher {
	expected := errorutils.NewCodedError(typeCode, "coded error")
	return &GomegaErrorMatcher{error: expected, code: typeCode, mask: errorutils.DefaultErrorCodeMask}
}
// GomegaErrorMatcher implements types.GomegaMatcher for error type.
// Match uses errors.Is against error; code and mask are only used to describe the expectation in failure messages.
type GomegaErrorMatcher struct {
error error
code int64
mask int64
}
// Match implements types.GomegaMatcher. actual must be an error; it matches when errors.Is(actual, m.error).
func (m *GomegaErrorMatcher) Match(actual interface{}) (success bool, err error) {
	if actualErr, isErr := actual.(error); isErr {
		return errors.Is(actualErr, m.error), nil
	}
	return false, fmt.Errorf(`%T is not an error`, actual)
}
// FailureMessage implements types.GomegaMatcher.
// When a code is known, the message shows the expected code and mask; otherwise it shows the expected error.
func (m *GomegaErrorMatcher) FailureMessage(actual interface{}) (message string) {
	var expectation string
	switch {
	case m.code != 0:
		expectation = fmt.Sprintf("to be an error with type [%T], code [%#016x] and mask [%#016x]", m.error, uint64(m.code), uint64(m.mask))
	default:
		expectation = fmt.Sprintf(`to equals to error [%T] - %v`, m.error, m.error)
	}
	return fmt.Sprintf("Expected\n%s\n%s\n", m.formatActual(actual), expectation)
}
// NegatedFailureMessage implements types.GomegaMatcher; it mirrors FailureMessage for the negated expectation.
func (m *GomegaErrorMatcher) NegatedFailureMessage(actual interface{}) (message string) {
	var expectation string
	switch {
	case m.code != 0:
		expectation = fmt.Sprintf("not to be an error with type [%T], code [%#016x] and mask [%#016x]", m.error, uint64(m.code), uint64(m.mask))
	default:
		expectation = fmt.Sprintf(`to not equal to error [%T] - %v`, m.error, m.error)
	}
	return fmt.Sprintf("Expected\n%s\n%s\n", m.formatActual(actual), expectation)
}
// formatActual renders actual with gomega's format.Object.
// If actual is (or wraps) an errorutils.ErrorCoder, its code is appended as " <Code 0x...>".
func (m *GomegaErrorMatcher) formatActual(actual interface{}) interface{} {
	desc := format.Object(actual, 1)
	var coder errorutils.ErrorCoder
	if actualErr, isErr := actual.(error); isErr && errors.As(actualErr, &coder) {
		desc += fmt.Sprintf(" <Code %#016x>", uint64(coder.Code()))
	}
	return desc
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package gomegautils
import (
"fmt"
"github.com/onsi/gomega"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/format"
"github.com/onsi/gomega/types"
"github.com/spyzhov/ajson"
"strings"
)
/****************************
Common Gomega Matchers
****************************/
// HaveJsonPathWithValue returns a gomega matcher.
// The matcher extracts values from a JSON string (or []byte) using JSONPath, and asserts that the resulting
// slice matches the expected value.
//
// "value" can be one of the following:
//   - a types.GomegaMatcher, which is applied directly to the slice produced by the JSONPath search
//   - any non-matcher value, which is wrapped as gomega.ContainElements(gomega.Equal(value))
//
// The following statements are equivalent:
//
//	Expect(jsonStr).To(HaveJsonPathWithValue("$..status", "GOOD"))
//	Expect(jsonStr).To(HaveJsonPathWithValue("$..status", gomega.ContainElements(gomega.Equal("GOOD"))))
func HaveJsonPathWithValue(jsonPath string, value interface{}) types.GomegaMatcher {
	matcher, isMatcher := value.(types.GomegaMatcher)
	if !isMatcher {
		matcher = gomega.ContainElements(gomega.Equal(value))
	}
	return &GomegaJsonPathMatcher{jsonPath: jsonPath, delegate: matcher}
}
// HaveJsonPath returns a gomega matcher similar to HaveJsonPathWithValue.
// It succeeds only if the given JSONPath yields a non-empty result from the actual JSON string.
//
// The following statements are equivalent:
//
//	Expect(jsonStr).To(HaveJsonPath("$..status"))
//	Expect(jsonStr).To(HaveJsonPathWithValue("$..status", gomega.Not(gomega.BeEmpty())))
func HaveJsonPath(jsonPath string) types.GomegaMatcher {
	nonEmpty := Not(BeEmpty())
	return &GomegaJsonPathMatcher{
		jsonPath: jsonPath,
		delegate: nonEmpty,
	}
}
// GomegaJsonPathMatcher implements types.GomegaMatcher: it extracts the values selected by jsonPath
// from a JSON string or []byte, and applies the delegate matcher to the resulting []interface{}.
type GomegaJsonPathMatcher struct {
jsonPath string
delegate types.GomegaMatcher
}
// Match implements types.GomegaMatcher: it evaluates the JSONPath against actual and delegates to the wrapped matcher.
// Extraction failures (non-JSON input, invalid path) are returned as errors.
func (m *GomegaJsonPathMatcher) Match(actual interface{}) (success bool, err error) {
	var values []interface{}
	if values, err = m.jsonPathValues(actual); err != nil {
		return false, err
	}
	return m.delegate.Match(values)
}
// FailureMessage implements types.GomegaMatcher: it shows the actual JSON, the expectation,
// and the delegate's failure message for the extracted values.
func (m *GomegaJsonPathMatcher) FailureMessage(actual interface{}) (message string) {
	jsonText := strings.TrimSpace(asString(actual))
	desc := format.Message(jsonText, fmt.Sprintf("to have JsonPath %s matching", m.jsonPath), m.delegate)
	values, _ := m.jsonPathValues(actual)
	return fmt.Sprintf("%s\nResult:\n%s", desc, m.delegate.FailureMessage(values))
}
// NegatedFailureMessage implements types.GomegaMatcher; it mirrors FailureMessage for the negated expectation.
func (m *GomegaJsonPathMatcher) NegatedFailureMessage(actual interface{}) (message string) {
	jsonText := strings.TrimSpace(asString(actual))
	desc := format.Message(jsonText, fmt.Sprintf("to have JsonPath %s not matching", m.jsonPath), m.delegate)
	values, _ := m.jsonPathValues(actual)
	return fmt.Sprintf("%s\nResult:\n%s", desc, m.delegate.NegatedFailureMessage(values))
}
// jsonPathValues parses actual (a string or []byte) as JSON, applies m.jsonPath, and returns the unpacked values
// of all matching nodes in the order ajson returns them.
// Errors wrap the underlying ajson error so that the reason for a failed match is visible.
func (m *GomegaJsonPathMatcher) jsonPathValues(actual interface{}) ([]interface{}, error) {
	var data []byte
	switch v := actual.(type) {
	case string:
		data = []byte(v)
	case []byte:
		data = v
	default:
		return nil, fmt.Errorf("expect string or []byte, but got %T", actual)
	}
	root, e := ajson.Unmarshal(data)
	if e != nil {
		return nil, fmt.Errorf(`expect json string but got %T: %w`, actual, e)
	}
	parsed, e := ajson.ParseJSONPath(m.jsonPath)
	if e != nil {
		return nil, fmt.Errorf("invalid JSONPath '%s': %w", m.jsonPath, e)
	}
	nodes, e := ajson.ApplyJSONPath(root, parsed)
	if e != nil {
		return nil, fmt.Errorf(`invalid JsonPath "%s": %w`, m.jsonPath, e)
	}
	values := make([]interface{}, len(nodes))
	for i, node := range nodes {
		v, e := node.Unpack()
		if e != nil {
			return nil, fmt.Errorf(`unable to extract value of JsonPath [%s]: %w`, m.jsonPath, e)
		}
		values[i] = v
	}
	return values, nil
}
/****************************
Gomega Matchers Helpers
****************************/
// asString converts a string or []byte to a string; any other type yields "".
func asString(actual interface{}) string {
	if s, ok := actual.(string); ok {
		return s
	}
	if b, ok := actual.([]byte); ok {
		return string(b)
	}
	return ""
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testutils
import (
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"testing"
)
// RuntimeTest is an optional fallback for PackageDirectory.
// When the call stack does not reveal the test's package, the directory of the source file
// declaring this function is used instead. Set it to one of the package's Test...() functions.
var RuntimeTest func(t *testing.T)

// PackageDirectory returns the directory of the package containing the test's Test...() function.
// It walks the call stack (up to 32 frames) using runtime. It skips frames whose source directory ends with
// "test", "test/utils", "testdata" or "runtime", and stops at Go's "testing" package.
// The last non-skipped directory seen is considered the package of Test...().
//
// Limitations:
//   - It does not work for tests whose directory ends with "test" or "test/utils".
//     The workaround is to set RuntimeTest to one of the test functions.
//   - It does not work in any "testdata" directory. There is no workaround.
//
// It panics if no directory can be determined.
func PackageDirectory() string {
	// deep enough to reach testing.tRunner through typical helper/framework frames
	const maxDepth = 32
	pcs := make([]uintptr, maxDepth)
	n := runtime.Callers(1, pcs)
	if n < 1 {
		panic("unable find package path")
	}
	// only the first n entries are valid program counters
	frames := runtime.CallersFrames(pcs[:n])
	var lastPkgDir string
LOOP:
	for {
		frame, more := frames.Next()
		dir := filepath.Dir(frame.File)
		switch {
		case frame.File == "":
			// unresolvable frame, ignore
		case strings.HasSuffix(dir, "testing"):
			break LOOP
		case strings.HasSuffix(dir, "test"),
			strings.HasSuffix(dir, "test/utils"),
			strings.HasSuffix(dir, "testdata"),
			strings.HasSuffix(dir, "runtime"):
			// test utilities and Go runtime frames are skipped
		default:
			lastPkgDir = dir
		}
		// "more" reports whether frames remain AFTER this one, so it is checked only after processing the frame
		if !more {
			break
		}
	}
	if len(lastPkgDir) == 0 {
		if RuntimeTest != nil {
			frame, _ := runtime.CallersFrames([]uintptr{reflect.ValueOf(RuntimeTest).Pointer()}).Next()
			return filepath.Dir(frame.File)
		}
		panic("unable find package path")
	}
	return lastPkgDir
}

// ProjectDirectory returns the directory of the test's project root (the closest ancestor containing go.mod),
// or "" if none is found. It builds on PackageDirectory(), so the same limitations apply.
func ProjectDirectory() string {
	dir := PackageDirectory()
	for dir != "" {
		if stat, e := os.Stat(filepath.Join(dir, "go.mod")); e == nil && !stat.IsDir() {
			return dir
		}
		parent := filepath.Dir(dir)
		if parent == dir {
			// Reached the filesystem root ("/", "C:\") or "." for relative paths (e.g. built with -trimpath).
			// Without this check the loop would never terminate.
			break
		}
		dir = parent
	}
	return ""
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package testutils
import (
"bufio"
"fmt"
"github.com/google/uuid"
"io/fs"
)
// UUIDPool holds a fixed list of UUIDs and hands them out in order via Pop,
// so tests can rely on a predictable sequence of IDs.
type UUIDPool struct {
// Pool is the list of UUIDs to hand out
Pool []uuid.UUID
// Current is the index of the next UUID Pop returns; it keeps increasing even after the pool is exhausted
Current int
}
// Pop returns the UUID at the current index and advances the index by one.
// Once the pool is exhausted it returns uuid.Nil and an error; the index keeps advancing regardless.
func (p *UUIDPool) Pop() (uuid.UUID, error) {
	idx := p.Current
	p.Current++
	if idx >= len(p.Pool) {
		return uuid.Nil, fmt.Errorf("UUID pool exhausted")
	}
	return p.Pool[idx], nil
}
// PopOrNew returns the next pooled UUID, or a freshly generated random one once the pool is exhausted.
func (p *UUIDPool) PopOrNew() uuid.UUID {
	if id, e := p.Pop(); e == nil {
		return id
	}
	return uuid.New()
}
// NewUUIDPool loads UUIDs, one per line, from the file src in fsys.
// Lines that are not valid UUIDs are skipped.
// It returns an error if the file cannot be opened or read, or if it contains no valid UUID.
func NewUUIDPool(fsys fs.FS, src string) (*UUIDPool, error) {
	f, e := fsys.Open(src)
	if e != nil {
		return nil, e
	}
	defer func() { _ = f.Close() }()

	scanner := bufio.NewScanner(f)
	scanner.Split(bufio.ScanLines)
	pool := make([]uuid.UUID, 0, 32)
	for scanner.Scan() {
		if id, e := uuid.Parse(scanner.Text()); e == nil {
			pool = append(pool, id)
		}
	}
	// Scan returns false on read errors too; without this check a failed read silently yields a partial pool
	if e := scanner.Err(); e != nil {
		return nil, fmt.Errorf("unable to read UUIDs from %s: %w", src, e)
	}
	if len(pool) == 0 {
		return nil, fmt.Errorf("unable to load UUIDs")
	}
	return &UUIDPool{
		Pool: pool,
	}, nil
}
package webtest
import (
"context"
"github.com/cisco-open/go-lanai/pkg/web"
"github.com/gin-gonic/gin"
"io"
"net/http"
"net/http/httptest"
"sync"
)
// ginContextRecorder is the gin.Context key under which CreateWithRequest stores the context's httptest.ResponseRecorder
const ginContextRecorder = `_go-lanai/webtest/gin/recorder`
// ginContextCreator lazily creates a gin.Engine (exactly once, via sync.Once) and uses it to create test *gin.Context instances
type ginContextCreator struct {
sync.Once
engine *gin.Engine
}
// LazyInit creates the underlying gin.Engine on first use: gin is switched to release mode,
// and the engine is created with ContextWithFallback enabled. Subsequent calls are no-ops.
func (c *ginContextCreator) LazyInit() {
	initEngine := func() {
		gin.SetMode(gin.ReleaseMode)
		engine := gin.New()
		engine.ContextWithFallback = true
		c.engine = engine
	}
	c.Do(initEngine)
}
// CreateWithRequest creates a *gin.Context backed by a new httptest.ResponseRecorder, which can be retrieved via GinContextRecorder.
// When req is non-nil, it is attached to the context and web.GinContextMerger is applied.
func (c *ginContextCreator) CreateWithRequest(req *http.Request) *gin.Context {
	c.LazyInit()
	recorder := httptest.NewRecorder()
	gc := gin.CreateTestContextOnly(recorder, c.engine)
	gc.Set(ginContextRecorder, recorder)
	if req == nil {
		return gc
	}
	gc.Request = req
	web.GinContextMerger()(gc)
	return gc
}
// defaultGinContextCreator is the shared creator used by NewGinContextWithRequest
var defaultGinContextCreator = &ginContextCreator{}
// NewGinContext builds an httptest request (method, path, body) bound to ctx, applies every non-nil RequestOptions
// in order, and returns a *gin.Context wrapping it (see NewGinContextWithRequest).
func NewGinContext(ctx context.Context, method, path string, body io.Reader, opts ...RequestOptions) *gin.Context {
	req := httptest.NewRequest(method, path, body)
	req = req.WithContext(ctx)
	for _, apply := range opts {
		if apply == nil {
			continue
		}
		apply(req)
	}
	return NewGinContextWithRequest(req)
}
// NewGinContextWithRequest returns a *gin.Context for req, created by the package's shared ginContextCreator.
func NewGinContextWithRequest(req *http.Request) *gin.Context {
	creator := defaultGinContextCreator
	return creator.CreateWithRequest(req)
}
// GinContextRecorder returns the httptest.ResponseRecorder attached to a context created by this package,
// or nil if gc has none.
func GinContextRecorder(gc *gin.Context) *httptest.ResponseRecorder {
	if rec, ok := gc.Value(ginContextRecorder).(*httptest.ResponseRecorder); ok {
		return rec
	}
	return nil
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package webtest
import (
"context"
"github.com/cisco-open/go-lanai/pkg/bootstrap"
"github.com/cisco-open/go-lanai/pkg/web"
"go.uber.org/fx"
)
// mockedWebModule is the bootstrap module that WithMockedServer installs in place of the real web module.
// It provides server properties, the web engine and the registrar, and invokes initialize.
// NOTE(review): presumably it differs from the real module by not starting an HTTP listener — confirm.
var mockedWebModule = &bootstrap.Module{
Name: "web",
Precedence: web.MinWebPrecedence,
PriorityOptions: []fx.Option{
fx.Provide(
web.BindServerProperties,
web.NewEngine,
web.NewRegistrar),
fx.Invoke(initialize),
},
}
// initDI holds the dependencies injected into initialize.
// Controllers, Customizers and ErrorTranslators are collected from the corresponding fx value groups.
type initDI struct {
fx.In
Registrar *web.Registrar
Properties web.ServerProperties
Controllers []web.Controller `group:"controllers"`
Customizers []web.Customizer `group:"customizers"`
ErrorTranslators []web.ErrorTranslator `group:"error_translators"`
}
// initialize registers the built-in logging, recovery and gin error-handling customizers, plus all injected
// controllers, customizers and error translators.
// It then appends an OnStart hook that initializes the registrar and, on success, immediately cleans it up
// (cleanup errors are ignored).
func initialize(lc fx.Lifecycle, di initDI) {
	registrar := di.Registrar
	// built-in customizers
	registrar.MustRegister(web.NewLoggingCustomizer(di.Properties))
	registrar.MustRegister(web.NewRecoveryCustomizer())
	registrar.MustRegister(web.NewGinErrorHandlingCustomizer())
	// components contributed through fx value groups
	registrar.MustRegister(di.Controllers)
	registrar.MustRegister(di.Customizers)
	registrar.MustRegister(di.ErrorTranslators)

	lc.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			if err := registrar.Initialize(ctx); err != nil {
				return err
			}
			_ = registrar.Cleanup(ctx)
			return nil
		},
	})
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package webtest
import (
"context"
"fmt"
"github.com/cisco-open/go-lanai/pkg/log"
"github.com/cisco-open/go-lanai/pkg/web"
webinit "github.com/cisco-open/go-lanai/pkg/web/init"
"github.com/cisco-open/go-lanai/test"
"github.com/cisco-open/go-lanai/test/apptest"
"go.uber.org/fx"
"testing"
)
const (
// DefaultContextPath is the context-path of the test server unless overridden via UseContextPath
DefaultContextPath = "/test"
)
// WithRealServer starts a real web server on a random port, with DefaultContextPath as context-path.
// NewRequest(), Exec() and MustExec() can be used to create and send requests and verify results.
// By default, the server doesn't allow CORS and has no security configured.
// The actual server port can be retrieved via CurrentPort().
// In this mode, *web.Registrar becomes available for injection.
func WithRealServer(opts ...TestServerOptions) test.Options {
	conf := TestServerConfig{ContextPath: DefaultContextPath, LogLevel: log.LevelInfo}
	for _, apply := range opts {
		apply(&conf)
	}
	var di realSvrDI
	return test.WithOptions(
		apptest.WithModules(webinit.Module),
		apptest.WithProperties(toProperties(&conf)...),
		apptest.WithDI(&di),
		test.SubTestSetup(testSetupAddrExtractor(&conf, &di)),
	)
}
// WithMockedServer initializes the web package without starting an actual web server.
// NewRequest(), Exec() and MustExec() can be used to create and send requests and verify results
// without creating an actual http connection.
// By default, the server doesn't allow CORS and has no security configured.
// In this mode, *web.Registrar becomes available for injection.
// Note: in this mode, the httptest package is used internally and the http.Handler (*web.Engine) is invoked directly.
func WithMockedServer(opts ...TestServerOptions) test.Options {
	conf := TestServerConfig{ContextPath: DefaultContextPath, LogLevel: log.LevelInfo}
	for _, apply := range opts {
		apply(&conf)
	}
	var di mockedSvrDI
	return test.WithOptions(
		apptest.WithModules(mockedWebModule),
		apptest.WithProperties(toProperties(&conf)...),
		apptest.WithDI(&di),
		test.SubTestSetup(testSetupEngineExtractor(&conf, &di)),
	)
}
// WithUtilities DOES NOT initialize the web package. It only provides server properties and the request
// utilities (e.g. MustExec).
// Important: this mode is mostly for go-lanai internal tests. DO NOT use it in microservices.
//
// NewRequest(), Exec() and MustExec() can be used to create and send requests and verify results
// without creating an actual http connection.
// Note: in this mode, the httptest package is used internally and the http.Handler (*web.Engine) is invoked directly.
func WithUtilities(opts ...TestServerOptions) test.Options {
	conf := TestServerConfig{ContextPath: DefaultContextPath, LogLevel: log.LevelInfo}
	for _, apply := range opts {
		apply(&conf)
	}
	var di mockedSvrDI
	propsProvider := fx.Provide(web.BindServerProperties)
	return test.WithOptions(
		apptest.WithProperties(toProperties(&conf)...),
		apptest.WithFxOptions(propsProvider),
		apptest.WithDI(&di),
		test.SubTestSetup(testSetupEngineExtractor(&conf, &di)),
	)
}
// UsePort returns a TestServerOptions that makes the test server listen on the given port.
// Note: a fixed port might cause conflicts when run in CI/CD.
func UsePort(port int) TestServerOptions {
	return func(c *TestServerConfig) { c.Port = port }
}
// UseContextPath returns a TestServerOptions that overrides the context-path of the test server.
func UseContextPath(contextPath string) TestServerOptions {
	return func(c *TestServerConfig) { c.ContextPath = contextPath }
}
// UseLogLevel returns a TestServerOptions that overrides the default log level of the test server.
func UseLogLevel(lvl log.LoggingLevel) TestServerOptions {
	return func(c *TestServerConfig) { c.LogLevel = lvl }
}
// AddDefaultRequestOptions returns a TestServerOptions that appends default RequestOptions,
// which are applied to every request created via NewRequest.
func AddDefaultRequestOptions(opts ...RequestOptions) TestServerOptions {
	return func(c *TestServerConfig) {
		c.RequestOptions = append(c.RequestOptions, opts...)
	}
}
// realSvrDI is injected in WithRealServer mode; Registrar is used to look up the actual server port.
type realSvrDI struct {
fx.In
Registrar *web.Registrar
Engine *web.Engine
}
// mockedSvrDI is injected in WithMockedServer and WithUtilities modes.
// Engine is optional because WithUtilities does not install a module that provides it.
type mockedSvrDI struct {
fx.In
Engine *web.Engine `optional:"true"`
}
// toProperties renders conf as "key: value" property strings for apptest.WithProperties.
// Server logging is always enabled.
func toProperties(conf *TestServerConfig) []string {
	port := fmt.Sprintf("server.port: %d", conf.Port)
	contextPath := fmt.Sprintf("server.context-path: %s", conf.ContextPath)
	logLevel := fmt.Sprintf("server.logging.default-level: %s", conf.LogLevel.String())
	return []string{port, contextPath, logLevel, "server.logging.enabled: true"}
}
// testSetupAddrExtractor returns a sub-test SetupFunc for real-server mode.
// At setup time, it reads the actual port from the injected registrar and stores the server info
// (127.0.0.1, port, context-path) in the returned context.
func testSetupAddrExtractor(conf *TestServerConfig, di *realSvrDI) test.SetupFunc {
	return func(ctx context.Context, _ *testing.T) (context.Context, error) {
		info := &serverInfo{
			hostname:    "127.0.0.1",
			port:        di.Registrar.ServerPort(),
			contextPath: conf.ContextPath,
		}
		return newWebTestContext(ctx, conf, info, nil), nil
	}
}
// testSetupEngineExtractor returns a sub-test SetupFunc for mocked/utility modes.
// It stores the context-path and the injected engine (possibly nil) in the returned context.
func testSetupEngineExtractor(conf *TestServerConfig, di *mockedSvrDI) test.SetupFunc {
	return func(ctx context.Context, _ *testing.T) (context.Context, error) {
		info := &serverInfo{contextPath: conf.ContextPath}
		return newWebTestContext(ctx, conf, info, di.Engine), nil
	}
}
// Copyright 2023 Cisco Systems, Inc. and its affiliates
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
package webtest
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"path"
)
// CurrentPort extracts the current test server port from the given testing context.
// It returns -1 when the context carries no webtest server info.
func CurrentPort(ctx context.Context) int {
	info, found := ctx.Value(ctxKeyInfo).(*serverInfo)
	if !found {
		return -1
	}
	return info.port
}
// CurrentContextPath extracts the current test server context-path from the given testing context.
// It returns DefaultContextPath when the context carries no webtest server info.
func CurrentContextPath(ctx context.Context) string {
	info, found := ctx.Value(ctxKeyInfo).(*serverInfo)
	if !found {
		return DefaultContextPath
	}
	return info.contextPath
}
// RequestOptions customizes an *http.Request in place. Options are applied by NewRequest (after any
// defaults registered via AddDefaultRequestOptions) and by Exec.
type RequestOptions func(req *http.Request)
// NewRequest creates a new *http.Request according to the current test server mode.
//
// WithRealServer mode:
//   - The created request has its Host and Port set to the current TestServer.
//   - If the given target is a relative path, "http" is used as scheme and the context path is
//     prepended to the given path.
//   - If the given target is an absolute URL, its Host and Port are overridden and its path is kept
//     unchanged.
//
// WithMockedServer mode:
//   - The returned request is created by `httptest.NewRequest` and cannot be used with
//     http.DefaultClient.Do().
//   - If the given target is a relative path, "http" is used as scheme and the context path is
//     prepended to the given path.
//   - If the given target is an absolute URL, its host, port and path are kept unchanged.
//
// For relative targets, a trailing slash is preserved, and an empty path (e.g. a query-only target
// such as "?k=v") resolves to the context path itself.
// Default RequestOptions registered on the test server are applied first, followed by opts.
//
// This function panics if the given target is not a valid absolute/relative URL, or if the test
// server is not enabled in ctx.
func NewRequest(ctx context.Context, method, target string, body io.Reader, opts ...RequestOptions) (req *http.Request) {
	tUrl, e := url.Parse(target)
	if e != nil {
		panic(fmt.Sprintf("invalid request target: %v", e))
	}
	info, ok := ctx.Value(ctxKeyInfo).(*serverInfo)
	if !ok {
		panic("invalid use of webtest.NewRequest(). Make sure webtest.WithRealServer() or webtest.WithMockedServer() is in-effect")
	}
	if !tUrl.IsAbs() {
		originalPath := tUrl.Path
		tUrl.Scheme = "http"
		tUrl.Path = path.Clean(path.Join(info.contextPath, originalPath))
		// path.Clean removes the trailing slash (except for root); put it back when the caller asked
		// for one so the request path matches what production would receive.
		// originalPath may be empty (e.g. target "" or "?k=v"), so the last-byte check must be guarded.
		if len(originalPath) != 0 && originalPath[len(originalPath)-1] == '/' && tUrl.Path != "/" {
			tUrl.Path += "/"
		}
	}
	if ctx.Value(ctxKeyHttpHandler) != nil {
		// WithMockedServer is in effect: build the request with httptest so it can be served in-process
		req = httptest.NewRequest(method, tUrl.String(), body).WithContext(ctx)
	} else {
		tUrl.Host = fmt.Sprintf("%s:%d", info.hostname, info.port)
		req, e = http.NewRequestWithContext(ctx, method, tUrl.String(), body)
		if e != nil {
			panic(e)
		}
	}
	applyRequestOptions(ctx, req, true, opts)
	return req
}
// Exec executes the given request according to the test server mode (real vs mocked), after
// applying opts to it (server-level default options are NOT re-applied here).
//
// The returned ExecResult is guaranteed to have a non-nil ExecResult.Response when there is no error.
// ExecResult.ResponseRecorder is non-nil only in WithMockedServer() mode.
// An error can only be returned in WithRealServer() mode.
// Note: the caller is responsible for closing the response's body when done with it.
//
//nolint:bodyclose // we don't close body here, whoever using this function should close it when done
func Exec(ctx context.Context, req *http.Request, opts ...RequestOptions) (ExecResult, error) {
	applyRequestOptions(ctx, req, false, opts)
	handler, mocked := ctx.Value(ctxKeyHttpHandler).(http.Handler)
	if !mocked {
		// no in-process handler: send the request over the network to the real server
		resp, e := http.DefaultClient.Do(req)
		return ExecResult{Response: resp}, e
	}
	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, req)
	return ExecResult{Response: recorder.Result(), ResponseRecorder: recorder}, nil
}
// MustExec behaves like Exec but panics with the error instead of returning it.
// Note: the caller is responsible for closing the response's body when done with it.
func MustExec(ctx context.Context, req *http.Request, opts ...RequestOptions) ExecResult {
	result, err := Exec(ctx, req, opts...)
	if err == nil {
		return result
	}
	panic(err)
}
// applyRequestOptions runs every option against req in order. When withDefaults is true and ctx
// carries a TestServerConfig with default RequestOptions, those defaults run first.
// The defaults are copied into a fresh slice so the config's slice is never appended to.
func applyRequestOptions(ctx context.Context, req *http.Request, withDefaults bool, opts []RequestOptions) {
	if withDefaults {
		conf, found := ctx.Value(ctxKeyConfig).(*TestServerConfig)
		if found && len(conf.RequestOptions) > 0 {
			merged := make([]RequestOptions, 0, len(conf.RequestOptions)+len(opts))
			merged = append(merged, conf.RequestOptions...)
			opts = append(merged, opts...)
		}
	}
	for i := range opts {
		opts[i](req)
	}
}
/*************************
Options
*************************/
// Headers returns a RequestOptions that adds headers given as alternating key/value pairs.
// Values are added (not replaced); a trailing key without a value is added with an empty value.
func Headers(kvs ...string) RequestOptions {
	return func(req *http.Request) {
		for idx := 0; idx < len(kvs); idx += 2 {
			value := ""
			if idx+1 < len(kvs) {
				value = kvs[idx+1]
			}
			req.Header.Add(kvs[idx], value)
		}
	}
}
// Queries returns a RequestOptions that adds query parameters given as alternating key/value pairs.
// Values are added to the existing query; a trailing key without a value is added with an empty
// value. The whole query is then re-encoded (url.Values.Encode sorts it by key).
func Queries(kvs ...string) RequestOptions {
	return func(req *http.Request) {
		q := req.URL.Query()
		for rest := kvs; len(rest) != 0; {
			key, value := rest[0], ""
			if len(rest) >= 2 {
				value, rest = rest[1], rest[2:]
			} else {
				rest = rest[1:]
			}
			q.Add(key, value)
		}
		req.URL.RawQuery = q.Encode()
	}
}
// Cookies returns a RequestOptions that carries the cookies set by the given response over to a
// request, e.g. to continue a session across calls.
//
// Only each cookie's name and value are sent, and they are joined into a single "Cookie" header
// via (*http.Request).AddCookie. Set-Cookie attributes such as Path, Domain, Expires or HttpOnly are
// response-only metadata; sending them would make servers parse them as additional bogus cookies.
// The response's cookies are read once, when Cookies is called.
func Cookies(resp *http.Response) RequestOptions {
	cookies := resp.Cookies()
	return func(req *http.Request) {
		for _, c := range cookies {
			req.AddCookie(c)
		}
	}
}
// ContentType returns a RequestOptions that adds the given value as a "Content-Type" header.
// The header value is added, not replaced.
func ContentType(v string) RequestOptions {
	return func(req *http.Request) {
		req.Header.Add("Content-Type", v)
	}
}
/*************************
Custom Context
*************************/
// Context keys used by this package. Each key has its own unexported empty struct type, so no other
// package can construct an equal key and collide with the values stored here.
type (
	infoCtxKey        struct{}
	configCtxKey      struct{}
	httpHandlerCtxKey struct{}
)

var (
	ctxKeyInfo        = infoCtxKey{}
	ctxKeyConfig      = configCtxKey{}
	ctxKeyHttpHandler = httpHandlerCtxKey{}
)
// serverInfo describes the test server that the current test context points to.
// testSetupAddrExtractor (real server) fills all three fields; testSetupEngineExtractor (mocked
// server) only sets contextPath, leaving hostname and port at their zero values.
type serverInfo struct {
	hostname    string
	port        int
	contextPath string
}
// webTestContext is a context.Context that carries webtest state for the helpers in this package.
// Its Value method answers the ctxKeyInfo, ctxKeyConfig and ctxKeyHttpHandler lookups from the fields
// below when they are non-nil, and delegates every other lookup to the embedded parent Context.
type webTestContext struct {
	context.Context
	// info is read via ctxKeyInfo by CurrentPort, CurrentContextPath and NewRequest.
	info *serverInfo
	// config is read via ctxKeyConfig by applyRequestOptions to obtain default RequestOptions.
	config *TestServerConfig
	// handler is the in-process handler used by Exec; testSetupAddrExtractor passes nil here.
	handler http.Handler
}
// newWebTestContext wraps parent in a webTestContext that exposes the given server info, config and
// (possibly nil) in-process handler through the package's context keys.
func newWebTestContext(parent context.Context, config *TestServerConfig, info *serverInfo, handler http.Handler) context.Context {
	wrapped := webTestContext{Context: parent}
	wrapped.info = info
	wrapped.config = config
	wrapped.handler = handler
	return &wrapped
}
// Value returns the webtest value for one of this package's context keys when the corresponding
// field is set; any other key, or a key whose field is nil, is looked up in the parent context.
func (c *webTestContext) Value(key interface{}) interface{} {
	if key == ctxKeyInfo && c.info != nil {
		return c.info
	}
	if key == ctxKeyConfig && c.config != nil {
		return c.config
	}
	if key == ctxKeyHttpHandler && c.handler != nil {
		return c.handler
	}
	return c.Context.Value(key)
}