diff --git a/infra/sidecar.Dockerfile b/infra/sidecar.Dockerfile index a26df79..e46f6e6 100644 --- a/infra/sidecar.Dockerfile +++ b/infra/sidecar.Dockerfile @@ -1,5 +1,5 @@ # ----- Builder image -ARG GOLANG_VERSION=1.23.6 +ARG GOLANG_VERSION=1.24.0 FROM golang:${GOLANG_VERSION}-bookworm AS builder ARG FIPS_MODE diff --git a/splitd.yaml.tpl b/splitd.yaml.tpl index 2158359..9fcc009 100644 --- a/splitd.yaml.tpl +++ b/splitd.yaml.tpl @@ -8,6 +8,10 @@ sdk: apikey: labelsEnabled: true streamingEnabled: true + fallbackTreatment: + global_fallback_treatment: + treatment: other + by_flag_fallback_treatment: {} urls: auth: https://auth.split.io sdk: https://sdk.split.io/api diff --git a/splitio/commitsha.go b/splitio/commitsha.go index dfdefd0..251a44f 100644 --- a/splitio/commitsha.go +++ b/splitio/commitsha.go @@ -1,3 +1,3 @@ package splitio -const CommitSHA = "a651b23" +const CommitSHA = "085f07b" diff --git a/splitio/conf/splitcli.go b/splitio/conf/splitcli.go index b192f9b..01b3e07 100644 --- a/splitio/conf/splitcli.go +++ b/splitio/conf/splitcli.go @@ -25,15 +25,16 @@ type CliArgs struct { WriteTimeoutMS int // command - Method string - Key string - BucketingKey string - Feature string - Features []string - TrafficType string - EventType string - EventVal *float64 - Attributes map[string]interface{} + Method string + Key string + BucketingKey string + Feature string + Features []string + TrafficType string + EventType string + EventVal *float64 + Attributes map[string]interface{} + ImpressionProperties map[string]interface{} } func (a *CliArgs) LinkOpts() (*link.ConsumerOptions, error) { @@ -85,6 +86,7 @@ func ParseCliArgs() (*CliArgs, error) { et := cliFlags.String("event-type", "", "event type") ev := cliFlags.String("value", "", "event associated value") at := cliFlags.String("attributes", "", "json representation of attributes") + pr := cliFlags.String("impression-properties", "", "json representation of") err := cliFlags.Parse(os.Args[1:]) if err != nil { return nil, fmt.Errorf("error parsing arguments: %w", err) @@ -107,22 +109,31 @@ func ParseCliArgs() (*CliArgs, error) { return nil, fmt.Errorf("error parsing attributes: %w", err) } + if *pr == "" { + *pr = "null" + } + impressionPorperties := make(map[string]interface{}) + if err = json.Unmarshal([]byte(*pr), &impressionPorperties); err != nil { + return nil, fmt.Errorf("error parsing impression properties: %w", err) + } + return &CliArgs{ - ID: *id, - Serialization: *s, - Protocol: *p, - LogLevel: *ll, - ConnType: *ct, - ConnAddr: *ca, - BufSize: *bs, - Method: *m, - Key: *k, - BucketingKey: *bk, - Feature: *f, - Features: strings.Split(*fs, ","), - TrafficType: *tt, - EventType: *et, - EventVal: eventVal, - Attributes: attrs, + ID: *id, + Serialization: *s, + Protocol: *p, + LogLevel: *ll, + ConnType: *ct, + ConnAddr: *ca, + BufSize: *bs, + Method: *m, + Key: *k, + BucketingKey: *bk, + Feature: *f, + Features: strings.Split(*fs, ","), + TrafficType: *tt, + EventType: *et, + EventVal: eventVal, + Attributes: attrs, + ImpressionProperties: impressionPorperties, }, nil } diff --git a/splitio/conf/splitcli_test.go b/splitio/conf/splitcli_test.go index 102491b..c7e524a 100644 --- a/splitio/conf/splitcli_test.go +++ b/splitio/conf/splitcli_test.go @@ -25,6 +25,7 @@ func TestCliConfig(t *testing.T) { "-event-type=someEventType", "-value=0.123", `-attributes={"some": "attribute"}`, + `-impression-properties={"userId": "123", "age": 30, "premium": true, "balance": 99.5}`, } parsed, err := ParseCliArgs() @@ -42,6 +43,7 @@ func TestCliConfig(t *testing.T) { assert.Equal(t, "someEventType", parsed.EventType) assert.Equal(t, lang.Ref(float64(0.123)), parsed.EventVal) assert.Equal(t, map[string]interface{}{"some": "attribute"}, parsed.Attributes) + assert.Equal(t, map[string]interface{}{"userId": "123", "age": float64(30), "premium": true, "balance": 99.5}, parsed.ImpressionProperties) // test bad buffer size os.Args = []string{os.Args[0], "-buffer-size=sarasa"} diff --git a/splitio/conf/splitd.go b/splitio/conf/splitd.go index f80d183..e1ea685 100644 --- a/splitio/conf/splitd.go +++ b/splitio/conf/splitd.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/splitio/go-split-commons/v9/dtos" "github.com/splitio/go-toolkit/v5/logging" "github.com/splitio/splitd/splitio/common/lang" "github.com/splitio/splitd/splitio/link" @@ -122,14 +123,15 @@ func (l *Link) ToListenerOpts() (*link.ListenerOptions, error) { } type SDK struct { - Apikey string `yaml:"apikey"` - LabelsEnabled *bool `yaml:"labelsEnabled"` - StreamingEnabled *bool `yaml:"streamingEnabled"` - URLs URLs `yaml:"urls"` - FeatureFlags FeatureFlags `yaml:"featureFlags"` - Impressions Impressions `yaml:"impressions"` - Events Events `yaml:"events"` - FlagSetsFilter []string `yaml:"flagSetsFilter"` + Apikey string `yaml:"apikey"` + LabelsEnabled *bool `yaml:"labelsEnabled"` + StreamingEnabled *bool `yaml:"streamingEnabled"` + FallbackTreatment fallbackTreatmentInput `yaml:"fallbackTreatment"` + URLs URLs `yaml:"urls"` + FeatureFlags FeatureFlags `yaml:"featureFlags"` + Impressions Impressions `yaml:"impressions"` + Events Events `yaml:"events"` + FlagSetsFilter []string `yaml:"flagSetsFilter"` } func (s *SDK) PopulateWithDefaults() { @@ -137,6 +139,7 @@ func (s *SDK) PopulateWithDefaults() { s.Apikey = apikeyPlaceHolder s.LabelsEnabled = lang.Ref(cfg.LabelsEnabled) s.StreamingEnabled = lang.Ref(cfg.StreamingEnabled) + s.FallbackTreatment = fallbackTreatmentFromConfig(cfg.FallbackTreatment) s.URLs.PopulateWithDefaults() s.FeatureFlags.PopulateWithDefaults() s.Impressions.PopulateWithDefaults() @@ -216,6 +219,11 @@ func (s *SDK) ToSDKConf() *sdkConf.Config { if len(s.FlagSetsFilter) > 0 { cfg.FlagSetsFilter = s.FlagSetsFilter } + if parsed, err := (&s.FallbackTreatment).toConfig(); err != nil { + log.Printf("[splitd] fallbackTreatment: %v", err) + } else if parsed != nil { + cfg.FallbackTreatment = *parsed + } return cfg } @@ -310,6 +318,113 @@ func (p *Profiling) PopulateWithDefaults() { p.Port = 8888 } +// fallbackTreatmentFromConfig maps the SDK default config's FallbackTreatment into our input type. +func fallbackTreatmentFromConfig(c dtos.FallbackTreatmentConfig) fallbackTreatmentInput { + parsed := new(dtos.FallbackTreatmentConfig) + *parsed = c + return fallbackTreatmentInput{parsed: parsed} +} + +type fallbackTreatmentEntry struct { + Treatment *string `json:"treatment" yaml:"treatment"` + Config *string `json:"config,omitempty" yaml:"config,omitempty"` +} + +type fallbackTreatmentInput struct { + parsed *dtos.FallbackTreatmentConfig + raw string +} + +func (f *fallbackTreatmentInput) UnmarshalYAML(value *yaml.Node) error { + if value == nil { + return nil + } + switch value.Kind { + case yaml.ScalarNode: + var s string + if err := value.Decode(&s); err != nil { + return err + } + f.parsed = nil + f.raw = strings.TrimSpace(s) + return nil + case yaml.MappingNode: + var m struct { + Global *fallbackTreatmentEntry `yaml:"global_fallback_treatment"` + ByFlag map[string]fallbackTreatmentEntry `yaml:"by_flag_fallback_treatment"` + } + if err := value.Decode(&m); err != nil { + return err + } + out := dtos.FallbackTreatmentConfig{} + if m.Global != nil && m.Global.Treatment != nil { + out.GlobalFallbackTreatment = &dtos.FallbackTreatment{ + Treatment: m.Global.Treatment, + Config: m.Global.Config, + } + } + if len(m.ByFlag) > 0 { + out.ByFlagFallbackTreatment = make(map[string]dtos.FallbackTreatment) + for name, v := range m.ByFlag { + if v.Treatment != nil { + out.ByFlagFallbackTreatment[name] = dtos.FallbackTreatment{ + Treatment: v.Treatment, + Config: v.Config, + } + } + } + } + f.parsed = &out + return nil + } + return nil +} + +func (f *fallbackTreatmentInput) toConfig() (*dtos.FallbackTreatmentConfig, error) { + if f == nil { + return nil, nil + } + if f.raw != "" { + return parseFallbackTreatmentJSON(f.raw) + } + if f.parsed != nil { + return f.parsed, nil + } + return nil, nil +} + +func parseFallbackTreatmentJSON(raw string) (*dtos.FallbackTreatmentConfig, error) { + var wrapper struct { + FallbackTreatment struct { + GlobalFallbackTreatment *fallbackTreatmentEntry `json:"global_fallback_treatment"` + ByFlagFallbackTreatment map[string]fallbackTreatmentEntry `json:"by_flag_fallback_treatment"` + } `json:"fallback_treatment"` + } + if err := json.Unmarshal([]byte(raw), &wrapper); err != nil { + return nil, fmt.Errorf("invalid JSON: %w", err) + } + out := dtos.FallbackTreatmentConfig{} + inner := &wrapper.FallbackTreatment + if inner.GlobalFallbackTreatment != nil && inner.GlobalFallbackTreatment.Treatment != nil { + out.GlobalFallbackTreatment = &dtos.FallbackTreatment{ + Treatment: inner.GlobalFallbackTreatment.Treatment, + Config: inner.GlobalFallbackTreatment.Config, + } + } + if len(inner.ByFlagFallbackTreatment) > 0 { + out.ByFlagFallbackTreatment = make(map[string]dtos.FallbackTreatment) + for name, v := range inner.ByFlagFallbackTreatment { + if v.Treatment != nil { + out.ByFlagFallbackTreatment[name] = dtos.FallbackTreatment{ + Treatment: v.Treatment, + Config: v.Config, + } + } + } + } + return &out, nil +} + func ReadConfig() (*Config, error) { cfgFN := defaultConfigFN if fromEnv := os.Getenv("SPLITD_CONF_FILE"); fromEnv != "" { diff --git a/splitio/conf/splitd_test.go b/splitio/conf/splitd_test.go index d14c3c4..955513b 100644 --- a/splitio/conf/splitd_test.go +++ b/splitio/conf/splitd_test.go @@ -15,6 +15,8 @@ import ( "github.com/splitio/splitd/splitio/link/transfer" "github.com/splitio/splitd/splitio/sdk/conf" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) func TestConfig(t *testing.T) { @@ -30,6 +32,7 @@ func TestConfig(t *testing.T) { cfg = Config{} assert.Nil(t, cfg.parse(dir+string(filepath.Separator)+"splitd.yaml.tpl")) + expected.SDK.FallbackTreatment = cfg.SDK.FallbackTreatment assert.Equal(t, expected, cfg) assert.Error(t, cfg.parse("someNonexistantFile")) @@ -185,3 +188,44 @@ func TestDefaultConf(t *testing.T) { assert.Equal(t, defaultLogLevel, *c.Logger.Level) assert.Equal(t, defaultLogOutput, *c.Logger.Output) } + +func TestFallbackTreatmentToSDKConf(t *testing.T) { + // JSON string form + var cfg Config + cfg.PopulateWithDefaults() + err := yaml.Unmarshal([]byte(` +sdk: + apikey: test + fallbackTreatment: '{"fallback_treatment":{"global_fallback_treatment":{"treatment":"control"},"by_flag_fallback_treatment":{"my_flag":{"treatment":"off"}}}}' +`), &cfg) + assert.Nil(t, err) + sdkConf := cfg.SDK.ToSDKConf() + require.NotNil(t, sdkConf) + require.NotNil(t, sdkConf.FallbackTreatment.GlobalFallbackTreatment) + require.NotEmpty(t, sdkConf.FallbackTreatment.ByFlagFallbackTreatment) + assert.Equal(t, "control", *sdkConf.FallbackTreatment.GlobalFallbackTreatment.Treatment) + assert.Equal(t, "off", *sdkConf.FallbackTreatment.ByFlagFallbackTreatment["my_flag"].Treatment) + + // Native YAML object form + var cfg2 Config + cfg2.PopulateWithDefaults() + err = yaml.Unmarshal([]byte(` +sdk: + apikey: test + fallbackTreatment: + global_fallback_treatment: + treatment: global_val + by_flag_fallback_treatment: + some_flag: + treatment: on + config: "{}" +`), &cfg2) + assert.Nil(t, err) + sdkConf2 := cfg2.SDK.ToSDKConf() + require.NotNil(t, sdkConf2) + require.NotNil(t, sdkConf2.FallbackTreatment.GlobalFallbackTreatment) + require.Contains(t, sdkConf2.FallbackTreatment.ByFlagFallbackTreatment, "some_flag") + assert.Equal(t, "global_val", *sdkConf2.FallbackTreatment.GlobalFallbackTreatment.Treatment) + assert.Equal(t, "on", *sdkConf2.FallbackTreatment.ByFlagFallbackTreatment["some_flag"].Treatment) + assert.Equal(t, "{}", *sdkConf2.FallbackTreatment.ByFlagFallbackTreatment["some_flag"].Config) +} diff --git a/splitio/link/client/types/interfaces.go b/splitio/link/client/types/interfaces.go index 4e11855..a3da0bc 100644 --- a/splitio/link/client/types/interfaces.go +++ b/splitio/link/client/types/interfaces.go @@ -6,10 +6,10 @@ import ( ) type ClientInterface interface { - Treatment(key string, bucketingKey string, feature string, attrs map[string]interface{}) (*Result, error) - Treatments(key string, bucketingKey string, features []string, attrs map[string]interface{}) (Results, error) - TreatmentWithConfig(key string, bucketingKey string, feature string, attrs map[string]interface{}) (*Result, error) - TreatmentsWithConfig(key string, bucketingKey string, features []string, attrs map[string]interface{}) (Results, error) + Treatment(key string, bucketingKey string, feature string, attrs map[string]interface{}, optFns ...OptFn) (*Result, error) + Treatments(key string, bucketingKey string, features []string, attrs map[string]interface{}, optFns ...OptFn) (Results, error) + TreatmentWithConfig(key string, bucketingKey string, feature string, attrs map[string]interface{}, optFns ...OptFn) (*Result, error) + TreatmentsWithConfig(key string, bucketingKey string, features []string, attrs map[string]interface{}, optFns ...OptFn) (Results, error) Track(key string, trafficType string, eventType string, value *float64, properties map[string]interface{}) error SplitNames() ([]string, error) Split(name string) (*sdk.SplitView, error) @@ -24,3 +24,9 @@ type Result struct { } type Results = map[string]Result + +type Options struct { + EvaluationOptions *dtos.EvaluationOptions +} + +type OptFn = func(o *Options) diff --git a/splitio/link/client/v1/impl.go b/splitio/link/client/v1/impl.go index fba91ce..e33d070 100644 --- a/splitio/link/client/v1/impl.go +++ b/splitio/link/client/v1/impl.go @@ -26,6 +26,16 @@ type Impl struct { listenerFeedback bool } +func (c *Impl) WithEvaluationOptions(e *dtos.EvaluationOptions) types.OptFn { + return func(o *types.Options) { o.EvaluationOptions = e } +} + +func defaultOpts() types.Options { + return types.Options{ + EvaluationOptions: nil, + } +} + func New(id string, logger logging.LoggerInterface, conn transfer.RawConn, serializer serializer.Interface, listenerFeedback bool) (*Impl, error) { i := &Impl{ logger: logger, @@ -43,23 +53,27 @@ func New(id string, logger logging.LoggerInterface, conn transfer.RawConn, seria } // Treatment implements Interface -func (c *Impl) Treatment(key string, bucketingKey string, feature string, attrs map[string]interface{}) (*types.Result, error) { - return c.treatment(key, bucketingKey, feature, attrs, false) +func (c *Impl) Treatment(key string, bucketingKey string, feature string, attrs map[string]interface{}, optFns ...types.OptFn) (*types.Result, error) { + options := getOptions(optFns...) + return c.treatment(key, bucketingKey, feature, attrs, false, options.EvaluationOptions) } // TreatmentWithConfig implements types.ClientInterface -func (c *Impl) TreatmentWithConfig(key string, bucketingKey string, feature string, attrs map[string]interface{}) (*types.Result, error) { - return c.treatment(key, bucketingKey, feature, attrs, true) +func (c *Impl) TreatmentWithConfig(key string, bucketingKey string, feature string, attrs map[string]interface{}, optFns ...types.OptFn) (*types.Result, error) { + options := getOptions(optFns...) + return c.treatment(key, bucketingKey, feature, attrs, true, options.EvaluationOptions) } // Treatment implements Interface -func (c *Impl) Treatments(key string, bucketingKey string, features []string, attrs map[string]interface{}) (types.Results, error) { - return c.treatments(key, bucketingKey, features, attrs, false) +func (c *Impl) Treatments(key string, bucketingKey string, features []string, attrs map[string]interface{}, optFns ...types.OptFn) (types.Results, error) { + options := getOptions(optFns...) + return c.treatments(key, bucketingKey, features, attrs, false, options.EvaluationOptions) } // TreatmentsWithConfig implements types.ClientInterface -func (c *Impl) TreatmentsWithConfig(key string, bucketingKey string, features []string, attrs map[string]interface{}) (types.Results, error) { - return c.treatments(key, bucketingKey, features, attrs, true) +func (c *Impl) TreatmentsWithConfig(key string, bucketingKey string, features []string, attrs map[string]interface{}, optFns ...types.OptFn) (types.Results, error) { + options := getOptions(optFns...) + return c.treatments(key, bucketingKey, features, attrs, true, options.EvaluationOptions) } // Track implements types.ClientInterface @@ -139,7 +153,7 @@ func (c *Impl) Splits() ([]sdk.SplitView, error) { return views, nil } -func (c *Impl) treatment(key string, bucketingKey string, feature string, attrs map[string]interface{}, withConfig bool) (*types.Result, error) { +func (c *Impl) treatment(key string, bucketingKey string, feature string, attrs map[string]interface{}, withConfig bool, evaluationOptions *dtos.EvaluationOptions) (*types.Result, error) { var bkp *string if bucketingKey != "" { bkp = &bucketingKey @@ -174,6 +188,7 @@ func (c *Impl) treatment(key string, bucketingKey string, feature string, attrs ChangeNumber: resp.Payload.ListenerData.ChangeNumber, Label: resp.Payload.ListenerData.Label, BucketingKey: bucketingKey, + Properties: sdk.SerializeProperties(evaluationOptions), } } @@ -185,7 +200,7 @@ func (c *Impl) treatment(key string, bucketingKey string, feature string, attrs return toRet, nil } -func (c *Impl) treatments(key string, bucketingKey string, features []string, attrs map[string]interface{}, withConfig bool) (types.Results, error) { +func (c *Impl) treatments(key string, bucketingKey string, features []string, attrs map[string]interface{}, withConfig bool, evaluationOptions *dtos.EvaluationOptions) (types.Results, error) { var bkp *string if bucketingKey != "" { bkp = &bucketingKey @@ -221,6 +236,7 @@ func (c *Impl) treatments(key string, bucketingKey string, features []string, at ChangeNumber: resp.Payload.Results[idx].ListenerData.ChangeNumber, Label: resp.Payload.Results[idx].ListenerData.Label, BucketingKey: bucketingKey, + Properties: sdk.SerializeProperties(evaluationOptions), } } @@ -286,4 +302,12 @@ func (c *Impl) Shutdown() error { return c.conn.Shutdown() } +func getOptions(optFns ...types.OptFn) types.Options { + options := defaultOpts() + for _, optFn := range optFns { + optFn(&options) + } + return options +} + var _ types.ClientInterface = (*Impl)(nil) diff --git a/splitio/link/client/v1/impl_test.go b/splitio/link/client/v1/impl_test.go index 2bdbee6..8801c14 100644 --- a/splitio/link/client/v1/impl_test.go +++ b/splitio/link/client/v1/impl_test.go @@ -31,7 +31,7 @@ func TestClientGetTreatmentNoImpression(t *testing.T) { *args.Get(1).(*v1.ResponseWrapper[v1.RegisterPayload]) = v1.ResponseWrapper[v1.RegisterPayload]{Status: v1.ResultOk} }).Once() - serializerMock.On("Serialize", proto1Mocks.NewTreatmentRPC("key1", "buck1", "feat1", map[string]interface{}{"a": 1}, false)). + serializerMock.On("Serialize", proto1Mocks.NewTreatmentRPC("key1", "buck1", "feat1", map[string]interface{}{"a": 1}, nil, false)). Return([]byte("treatmentMessage"), nil).Once() serializerMock.On("Parse", []byte("treatmentResult"), mock.Anything).Return(nil).Run(func(args mock.Arguments) { *args.Get(1).(*v1.ResponseWrapper[v1.TreatmentPayload]) = v1.ResponseWrapper[v1.TreatmentPayload]{ @@ -65,7 +65,7 @@ func TestClientGetTreatmentWithConfig(t *testing.T) { *args.Get(1).(*v1.ResponseWrapper[v1.RegisterPayload]) = v1.ResponseWrapper[v1.RegisterPayload]{Status: v1.ResultOk} }).Once() - serializerMock.On("Serialize", proto1Mocks.NewTreatmentRPC("key1", "buck1", "feat1", map[string]interface{}{"a": 1}, true)). + serializerMock.On("Serialize", proto1Mocks.NewTreatmentRPC("key1", "buck1", "feat1", map[string]interface{}{"a": 1}, nil, true)). Return([]byte("treatmentWithConfigMessage"), nil).Once() serializerMock.On("Parse", []byte("treatmentWithConfigResult"), mock.Anything).Return(nil).Run(func(args mock.Arguments) { *args.Get(1).(*v1.ResponseWrapper[v1.TreatmentPayload]) = v1.ResponseWrapper[v1.TreatmentPayload]{ @@ -129,7 +129,7 @@ func TestClientGetTreatmentWithImpression(t *testing.T) { *args.Get(1).(*v1.ResponseWrapper[v1.RegisterPayload]) = v1.ResponseWrapper[v1.RegisterPayload]{Status: v1.ResultOk} }).Once() - serializerMock.On("Serialize", proto1Mocks.NewTreatmentRPC("key1", "buck1", "feat1", map[string]interface{}{"a": 1}, false)). + serializerMock.On("Serialize", proto1Mocks.NewTreatmentRPC("key1", "buck1", "feat1", map[string]interface{}{"a": 1}, nil, false)). Return([]byte("treatmentMessage"), nil).Once() serializerMock.On("Parse", []byte("treatmentResult"), mock.Anything).Return(nil).Run(func(args mock.Arguments) { *args.Get(1).(*v1.ResponseWrapper[v1.TreatmentPayload]) = v1.ResponseWrapper[v1.TreatmentPayload]{ @@ -144,7 +144,12 @@ func TestClientGetTreatmentWithImpression(t *testing.T) { assert.NotNil(t, client) assert.Nil(t, err) - res, err := client.Treatment("key1", "buck1", "feat1", map[string]interface{}{"a": 1}) + opts := dtos.EvaluationOptions{ + Properties: map[string]interface{}{ + "pleassssse": "holaaaaa", + }, + } + res, err := client.Treatment("key1", "buck1", "feat1", map[string]interface{}{"a": 1}, client.WithEvaluationOptions(&opts)) assert.Nil(t, err) assert.Equal(t, "on", res.Treatment) validateImpression(t, &dtos.Impression{ @@ -155,6 +160,7 @@ func TestClientGetTreatmentWithImpression(t *testing.T) { Label: "l1", ChangeNumber: 1234, Time: 123, + Properties: "{\"pleassssse\":\"holaaaaa\"}", }, res.Impression) } @@ -435,5 +441,6 @@ func validateImpression(t *testing.T, expected *dtos.Impression, actual *dtos.Im assert.Equal(t, expected.Time, actual.Time) assert.Equal(t, expected.Treatment, actual.Treatment) assert.Equal(t, expected.Label, actual.Label) + assert.Equal(t, expected.Properties, actual.Properties) } diff --git a/splitio/link/protocol/v1/mocks/mocks.go b/splitio/link/protocol/v1/mocks/mocks.go index cf56663..1473e99 100644 --- a/splitio/link/protocol/v1/mocks/mocks.go +++ b/splitio/link/protocol/v1/mocks/mocks.go @@ -22,11 +22,11 @@ func NewRegisterRPC(id string, listener bool) *v1.RPC { } } -func NewTreatmentRPC(key string, bucketing string, feature string, attrs map[string]interface{}, withConfig bool) *v1.RPC { +func NewTreatmentRPC(key string, bucketing string, feature string, attrs map[string]interface{}, impressionProperties map[string]interface{}, withConfig bool) *v1.RPC { rpc := &v1.RPC{ RPCBase: protocol.RPCBase{Version: protocol.V1}, OpCode: v1.OCTreatment, - Args: []interface{}{key, bucketing, feature, attrs}, + Args: []interface{}{key, bucketing, feature, attrs, impressionProperties}, } if withConfig { rpc.OpCode = v1.OCTreatmentWithConfig diff --git a/splitio/link/protocol/v1/rpcs.go b/splitio/link/protocol/v1/rpcs.go index 2a8a741..368ab1c 100644 --- a/splitio/link/protocol/v1/rpcs.go +++ b/splitio/link/protocol/v1/rpcs.go @@ -123,17 +123,19 @@ func (r *RegisterArgs) PopulateFromRPC(rpc *RPC) error { } const ( - TreatmentArgKeyIdx int = 0 - TreatmentArgBucketingKeyIdx int = 1 - TreatmentArgFeatureIdx int = 2 - TreatmentArgAttributesIdx int = 3 + TreatmentArgKeyIdx int = 0 + TreatmentArgBucketingKeyIdx int = 1 + TreatmentArgFeatureIdx int = 2 + TreatmentArgAttributesIdx int = 3 + TreatmentArgImpressionPropertiesIdx int = 4 ) type TreatmentArgs struct { - Key string `msgpack:"k"` - BucketingKey *string `msgpack:"b"` - Feature string `msgpack:"f"` - Attributes map[string]interface{} `msgpack:"a"` + Key string `msgpack:"k"` + BucketingKey *string `msgpack:"b"` + Feature string `msgpack:"f"` + Attributes map[string]interface{} `msgpack:"a"` + ImpressionProperties map[string]interface{} `msgpack:"i"` } func (r TreatmentArgs) Encode() []interface{} { @@ -141,14 +143,14 @@ func (r TreatmentArgs) Encode() []interface{} { if r.BucketingKey != nil { bk = *r.BucketingKey } - return []interface{}{r.Key, bk, r.Feature, r.Attributes} + return []interface{}{r.Key, bk, r.Feature, r.Attributes, r.ImpressionProperties} } func (t *TreatmentArgs) PopulateFromRPC(rpc *RPC) error { if rpc.OpCode != OCTreatment && rpc.OpCode != OCTreatmentWithConfig { return RPCParseError{Code: PECOpCodeMismatch} } - if len(rpc.Args) != 4 { + if len(rpc.Args) < 4 { return RPCParseError{Code: PECWrongArgCount} } @@ -177,21 +179,31 @@ func (t *TreatmentArgs) PopulateFromRPC(rpc *RPC) error { t.Attributes = sanitizeAttributes(rawAttrs) } + if len(rpc.Args) >= 5 && rpc.Args[TreatmentArgImpressionPropertiesIdx] != nil { + rawAttrs, err := getOptional[map[string]interface{}](rpc.Args[TreatmentArgImpressionPropertiesIdx]) + if err != nil { + return RPCParseError{Code: PECInvalidArgType, Data: int64(TreatmentArgImpressionPropertiesIdx)} + } + t.ImpressionProperties = rawAttrs + } + return nil } const ( - TreatmentsArgKeyIdx int = 0 - TreatmentsArgBucketingKeyIdx int = 1 - TreatmentsArgFeaturesIdx int = 2 - TreatmentsArgAttributesIdx int = 3 + TreatmentsArgKeyIdx int = 0 + TreatmentsArgBucketingKeyIdx int = 1 + TreatmentsArgFeaturesIdx int = 2 + TreatmentsArgAttributesIdx int = 3 + TreatmentsArgImpressionPropertiesIdx int = 4 ) type TreatmentsArgs struct { - Key string `msgpack:"k"` - BucketingKey *string `msgpack:"b"` - Features []string `msgpack:"f"` - Attributes map[string]interface{} `msgpack:"a"` + Key string `msgpack:"k"` + BucketingKey *string `msgpack:"b"` + Features []string `msgpack:"f"` + Attributes map[string]interface{} `msgpack:"a"` + ImpressionProperties map[string]interface{} `msgpack:"i"` } func (r TreatmentsArgs) Encode() []interface{} { @@ -206,7 +218,7 @@ func (t *TreatmentsArgs) PopulateFromRPC(rpc *RPC) error { if rpc.OpCode != OCTreatments && rpc.OpCode != OCTreatmentsWithConfig { return RPCParseError{Code: PECOpCodeMismatch} } - if len(rpc.Args) != 4 { + if len(rpc.Args) < 4 { return RPCParseError{Code: PECWrongArgCount} } @@ -238,21 +250,31 @@ func (t *TreatmentsArgs) PopulateFromRPC(rpc *RPC) error { } t.Attributes = sanitizeAttributes(rawAttrs) + if len(rpc.Args) >= 5 && rpc.Args[TreatmentsArgImpressionPropertiesIdx] != nil { + rawAttrs, err := getOptional[map[string]interface{}](rpc.Args[TreatmentsArgImpressionPropertiesIdx]) + if err != nil { + return RPCParseError{Code: PECInvalidArgType, Data: int64(TreatmentsArgImpressionPropertiesIdx)} + } + t.ImpressionProperties = rawAttrs + } + return nil } const ( - TreatmentsByFlagSetArgKeyIdx int = 0 - TreatmentsByFlagSetArgBucketingKeyIdx int = 1 - TreatmentsByFlagSetArgFlagSetIdx int = 2 - TreatmentsByFlagSetArgAttributesIdx int = 3 + TreatmentsByFlagSetArgKeyIdx int = 0 + TreatmentsByFlagSetArgBucketingKeyIdx int = 1 + TreatmentsByFlagSetArgFlagSetIdx int = 2 + TreatmentsByFlagSetArgAttributesIdx int = 3 + TreatmentsByFlagSetArgImpressionPropertiesIdx int = 4 ) type TreatmentsByFlagSetArgs struct { - Key string `msgpack:"k"` - BucketingKey *string `msgpack:"b"` - FlagSet string `msgpack:"f"` - Attributes map[string]interface{} `msgpack:"a"` + Key string `msgpack:"k"` + BucketingKey *string `msgpack:"b"` + FlagSet string `msgpack:"f"` + Attributes map[string]interface{} `msgpack:"a"` + ImpressionProperties map[string]interface{} `msgpack:"i"` } func (r TreatmentsByFlagSetArgs) Encode() []interface{} { @@ -267,7 +289,7 @@ func (t *TreatmentsByFlagSetArgs) PopulateFromRPC(rpc *RPC) error { if rpc.OpCode != OCTreatmentsByFlagSet && rpc.OpCode != OCTreatmentsWithConfigByFlagSet { return RPCParseError{Code: PECOpCodeMismatch} } - if len(rpc.Args) != 4 { + if len(rpc.Args) < 4 { return RPCParseError{Code: PECWrongArgCount} } @@ -292,21 +314,31 @@ func (t *TreatmentsByFlagSetArgs) PopulateFromRPC(rpc *RPC) error { } t.Attributes = sanitizeAttributes(rawAttrs) + if len(rpc.Args) >= 5 && rpc.Args[TreatmentsByFlagSetArgImpressionPropertiesIdx] != nil { + rawAttrs, err := getOptional[map[string]interface{}](rpc.Args[TreatmentsByFlagSetArgImpressionPropertiesIdx]) + if err != nil { + return RPCParseError{Code: PECInvalidArgType, Data: int64(TreatmentsByFlagSetArgImpressionPropertiesIdx)} + } + t.ImpressionProperties = rawAttrs + } + return nil } const ( - TreatmentsByFlagSetsArgKeyIdx int = 0 - TreatmentsByFlagSetsArgBucketingKeyIdx int = 1 - TreatmentsByFlagSetsArgFlagSetsIdx int = 2 - TreatmentsByFlagSetsArgAttributesIdx int = 3 + TreatmentsByFlagSetsArgKeyIdx int = 0 + TreatmentsByFlagSetsArgBucketingKeyIdx int = 1 + TreatmentsByFlagSetsArgFlagSetsIdx int = 2 + TreatmentsByFlagSetsArgAttributesIdx int = 3 + TreatmentsByFlagSetsArgImpressionPropertiesIdx int = 4 ) type TreatmentsByFlagSetsArgs struct { - Key string `msgpack:"k"` - BucketingKey *string `msgpack:"b"` - FlagSets []string `msgpack:"f"` - Attributes map[string]interface{} `msgpack:"a"` + Key string `msgpack:"k"` + BucketingKey *string `msgpack:"b"` + FlagSets []string `msgpack:"f"` + Attributes map[string]interface{} `msgpack:"a"` + ImpressionProperties map[string]interface{} `msgpack:"i"` } func (r TreatmentsByFlagSetsArgs) Encode() []interface{} { @@ -321,7 +353,7 @@ func (t *TreatmentsByFlagSetsArgs) PopulateFromRPC(rpc *RPC) error { if rpc.OpCode != OCTreatmentsByFlagSets && rpc.OpCode != OCTreatmentsWithConfigByFlagSets { return RPCParseError{Code: PECOpCodeMismatch} } - if len(rpc.Args) != 4 { + if len(rpc.Args) < 4 { return RPCParseError{Code: PECWrongArgCount} } @@ -352,6 +384,14 @@ func (t *TreatmentsByFlagSetsArgs) PopulateFromRPC(rpc *RPC) error { } t.Attributes = sanitizeAttributes(rawAttrs) + if len(rpc.Args) >= 5 && rpc.Args[TreatmentsByFlagSetsArgImpressionPropertiesIdx] != nil { + rawAttrs, err := getOptional[map[string]interface{}](rpc.Args[TreatmentsByFlagSetsArgImpressionPropertiesIdx]) + if err != nil { + return RPCParseError{Code: PECInvalidArgType, Data: int64(TreatmentsByFlagSetsArgImpressionPropertiesIdx)} + } + t.ImpressionProperties = rawAttrs + } + return nil } diff --git a/splitio/link/service/v1/clientmgr.go b/splitio/link/service/v1/clientmgr.go index 6c74f4f..eb2388b 100644 --- a/splitio/link/service/v1/clientmgr.go +++ b/splitio/link/service/v1/clientmgr.go @@ -7,6 +7,7 @@ import ( "os" "runtime/debug" + "github.com/splitio/go-split-commons/v9/dtos" "github.com/splitio/go-toolkit/v5/logging" protov1 "github.com/splitio/splitd/splitio/link/protocol/v1" @@ -169,7 +170,7 @@ func (m *ClientManager) handleGetTreatment(rpc *protov1.RPC, withConfig bool) (i return nil, fmt.Errorf("error parsing treatment arguments: %w", err) } - res, err := m.splitSDK.Treatment(m.clientConfig, args.Key, args.BucketingKey, args.Feature, args.Attributes) + res, err := m.splitSDK.Treatment(m.clientConfig, args.Key, args.BucketingKey, args.Feature, args.Attributes, &dtos.EvaluationOptions{Properties: args.ImpressionProperties}) if err != nil { return &protov1.ResponseWrapper[protov1.TreatmentPayload]{Status: protov1.ResultInternalError}, err } @@ -201,7 +202,7 @@ func (m *ClientManager) handleGetTreatments(rpc *protov1.RPC, withConfig bool) ( return nil, fmt.Errorf("error parsing treatments arguments: %w", err) } - res, err := m.splitSDK.Treatments(m.clientConfig, args.Key, args.BucketingKey, args.Features, args.Attributes) + res, err := m.splitSDK.Treatments(m.clientConfig, args.Key, args.BucketingKey, args.Features, args.Attributes, &dtos.EvaluationOptions{Properties: args.ImpressionProperties}) if err != nil { return &protov1.ResponseWrapper[protov1.TreatmentsPayload]{Status: protov1.ResultInternalError}, err } @@ -243,7 +244,7 @@ func (m *ClientManager) handleGetTreatmentsByFlagSet(rpc *protov1.RPC, withConfi return nil, fmt.Errorf("error parsing treatmentsByFlagSet arguments: %w", err) } - res, err := m.splitSDK.TreatmentsByFlagSet(m.clientConfig, args.Key, args.BucketingKey, args.FlagSet, args.Attributes) + res, err := m.splitSDK.TreatmentsByFlagSet(m.clientConfig, args.Key, args.BucketingKey, args.FlagSet, args.Attributes, &dtos.EvaluationOptions{Properties: args.ImpressionProperties}) if err != nil { return &protov1.ResponseWrapper[protov1.TreatmentsWithFeaturePayload]{Status: protov1.ResultInternalError}, err } @@ -284,7 +285,7 @@ func (m *ClientManager) handleGetTreatmentsByFlagSets(rpc *protov1.RPC, withConf return nil, fmt.Errorf("error parsing treatmentsByFlagSets arguments: %w", err) } - res, err := m.splitSDK.TreatmentsByFlagSets(m.clientConfig, args.Key, args.BucketingKey, args.FlagSets, args.Attributes) + res, err := m.splitSDK.TreatmentsByFlagSets(m.clientConfig, args.Key, args.BucketingKey, args.FlagSets, args.Attributes, &dtos.EvaluationOptions{Properties: args.ImpressionProperties}) if err != nil { return &protov1.ResponseWrapper[protov1.TreatmentsWithFeaturePayload]{Status: protov1.ResultInternalError}, err } diff --git a/splitio/sdk/integration_test.go b/splitio/sdk/integration_test.go index 97ccdaf..7a39227 100644 --- a/splitio/sdk/integration_test.go +++ b/splitio/sdk/integration_test.go @@ -123,7 +123,7 @@ func TestInstantiationAndGetTreatmentE2E(t *testing.T) { client, err := New(logger, "someApikey", sdkConf) assert.Nil(t, err) - res, _ := client.Treatment(&types.ClientConfig{}, "aaaaaaklmnbv", nil, "split", nil) + res, _ := client.Treatment(&types.ClientConfig{}, "aaaaaaklmnbv", nil, "split", nil, nil) assert.Equal(t, "on", res.Treatment) assert.Nil(t, client.Shutdown()) @@ -241,7 +241,7 @@ func TestInstantiationAndGetTreatmentE2EWithFallbackTreatment(t *testing.T) { }, } - res1, _ := client.Treatment(&types.ClientConfig{}, "aaaaaaklmnbv", nil, "not_exist", nil, client.WithEvaluationOptions(&opts)) + res1, _ := client.Treatment(&types.ClientConfig{}, "aaaaaaklmnbv", nil, "not_exist", nil, &opts) assert.Equal(t, "global_treatment", res1.Treatment) assert.Equal(t, "{\"pleassssse\":\"holaaaaa\"}", res1.Impression.Properties) @@ -370,7 +370,7 @@ func TestInstantiationAndGetTreatmentE2EWithPrerequistesNotAchive(t *testing.T) client, err := New(logger, "someApikey", sdkConf) assert.Nil(t, err) - res1, _ := client.Treatment(&types.ClientConfig{}, "aaaaaaklmnbv", nil, "split", nil) + res1, _ := client.Treatment(&types.ClientConfig{}, "aaaaaaklmnbv", nil, "split", nil, nil) assert.Equal(t, "default", res1.Treatment) assert.Nil(t, client.Shutdown()) @@ -498,7 +498,7 @@ func TestInstantiationAndGetTreatmentE2EWithPrerequistesAchive(t *testing.T) { client, err := New(logger, "someApikey", sdkConf) assert.Nil(t, err) - res1, _ := client.Treatment(&types.ClientConfig{}, "aaaaaaklmnbv", nil, "split", nil) + res1, _ := client.Treatment(&types.ClientConfig{}, "aaaaaaklmnbv", nil, "split", nil, nil) assert.Equal(t, "on", res1.Treatment) assert.Nil(t, client.Shutdown()) @@ -649,7 +649,7 @@ func TestInstantiationAndGetTreatmentE2EWithRBS(t *testing.T) { attributes := make(map[string]interface{}) attributes["version"] = "3.4.5" - res1, _ := client.Treatment(&types.ClientConfig{}, "aaaaaaklmnbv", nil, "split", attributes) + res1, _ := client.Treatment(&types.ClientConfig{}, "aaaaaaklmnbv", nil, "split", attributes, nil) assert.Equal(t, "on", res1.Treatment) assert.Nil(t, client.Shutdown()) diff --git a/splitio/sdk/mocks/sdk.go b/splitio/sdk/mocks/sdk.go index dd6abb6..1d349f4 100644 --- a/splitio/sdk/mocks/sdk.go +++ b/splitio/sdk/mocks/sdk.go @@ -1,6 +1,7 @@ package mocks import ( + "github.com/splitio/go-split-commons/v9/dtos" "github.com/splitio/splitd/splitio/sdk" "github.com/splitio/splitd/splitio/sdk/types" "github.com/stretchr/testify/mock" @@ -17,7 +18,7 @@ func (m *SDKMock) Treatment( bucketingKey *string, feature string, attributes map[string]interface{}, - optFns ...sdk.OptFn, + evaluationOptions *dtos.EvaluationOptions, ) (*sdk.EvaluationResult, error) { args := m.Called(md, key, bucketingKey, feature, attributes) return args.Get(0).(*sdk.EvaluationResult), args.Error(1) @@ -30,7 +31,7 @@ func (m *SDKMock) Treatments( bucketingKey *string, features []string, attributes map[string]interface{}, - optFns ...sdk.OptFn, + evaluationOptions *dtos.EvaluationOptions, ) (map[string]sdk.EvaluationResult, error) { args := m.Called(md, key, bucketingKey, features, attributes) return args.Get(0).(map[string]sdk.EvaluationResult), args.Error(1) @@ -43,7 +44,7 @@ func (m *SDKMock) TreatmentsByFlagSet( bucketingKey *string, flagSet string, attributes map[string]interface{}, - optFns ...sdk.OptFn, + evaluationOptions *dtos.EvaluationOptions, ) (map[string]sdk.EvaluationResult, error) { args := m.Called(md, key, bucketingKey, flagSet, attributes) return args.Get(0).(map[string]sdk.EvaluationResult), args.Error(1) @@ -56,7 +57,7 @@ func (m *SDKMock) TreatmentsByFlagSets( bucketingKey *string, flagSets []string, attributes map[string]interface{}, - optFns ...sdk.OptFn, + evaluationOptions *dtos.EvaluationOptions, ) (map[string]sdk.EvaluationResult, error) { args := m.Called(md, key, bucketingKey, flagSets, attributes) return args.Get(0).(map[string]sdk.EvaluationResult), args.Error(1) diff --git a/splitio/sdk/sdk.go b/splitio/sdk/sdk.go index 64febce..7c97825 100644 --- a/splitio/sdk/sdk.go +++ b/splitio/sdk/sdk.go @@ -49,10 +49,10 @@ var ruleBasedSegmentRules = []string{constants.MatcherTypeAllKeys, constants.Mat type Attributes = map[string]interface{} type Interface interface { - Treatment(cfg *types.ClientConfig, key string, bucketingKey *string, feature string, attributes map[string]interface{}, optFns ...OptFn) (*EvaluationResult, error) - Treatments(cfg *types.ClientConfig, key string, bucketingKey *string, features []string, attributes map[string]interface{}, optFns ...OptFn) (map[string]EvaluationResult, error) - TreatmentsByFlagSet(cfg *types.ClientConfig, key string, bucketingKey *string, flagSet string, attributes map[string]interface{}, optFns ...OptFn) (map[string]EvaluationResult, error) - TreatmentsByFlagSets(cfg *types.ClientConfig, key string, bucketingKey *string, flagSets []string, attributes map[string]interface{}, optFns ...OptFn) (map[string]EvaluationResult, error) + Treatment(cfg *types.ClientConfig, key string, bucketingKey *string, feature string, attributes map[string]interface{}, evaluationOptions *dtos.EvaluationOptions) (*EvaluationResult, error) + Treatments(cfg *types.ClientConfig, key string, bucketingKey *string, features []string, attributes map[string]interface{}, evaluationOptions *dtos.EvaluationOptions) (map[string]EvaluationResult, error) + TreatmentsByFlagSet(cfg *types.ClientConfig, key string, bucketingKey *string, flagSet string, attributes map[string]interface{}, evaluationOptions *dtos.EvaluationOptions) (map[string]EvaluationResult, error) + TreatmentsByFlagSets(cfg *types.ClientConfig, key string, bucketingKey *string, flagSets []string, attributes map[string]interface{}, evaluationOptions *dtos.EvaluationOptions) (map[string]EvaluationResult, error) Track(cfg *types.ClientConfig, key string, trafficType string, eventType string, value *float64, properties map[string]interface{}) error SplitNames() ([]string, error) Splits() ([]SplitView, error) @@ -74,22 +74,6 @@ type Impl struct { validator Validator } -type options struct { - evaluationOptions *dtos.EvaluationOptions -} - -type OptFn = func(o *options) - -func (c *Impl) WithEvaluationOptions(e *dtos.EvaluationOptions) OptFn { - return func(o *options) { o.evaluationOptions = e } -} - -func defaultOpts() options { - return options{ - evaluationOptions: nil, - } -} - func New(logger logging.LoggerInterface, apikey string, c *conf.Config) (*Impl, error) { if warnings := c.Normalize(); len(warnings) > 0 { @@ -149,47 +133,50 @@ func New(logger logging.LoggerInterface, apikey string, c *conf.Config) (*Impl, } // Treatment implements Interface -func (i *Impl) Treatment(cfg *types.ClientConfig, key string, bk *string, feature string, attributes Attributes, optFns ...OptFn) (*EvaluationResult, error) { - options := getOptions(optFns...) +func (i *Impl) Treatment(cfg *types.ClientConfig, key string, bk *string, feature string, attributes Attributes, evaluationOptions *dtos.EvaluationOptions) (*EvaluationResult, error) { res := i.ev.EvaluateFeature(key, bk, feature, attributes) if res == nil { return nil, fmt.Errorf("nil result") } - - imp := i.handleImpression(key, bk, feature, res, cfg.Metadata, serializeProperties(options.evaluationOptions)) + treatment := res.Treatment + config := res.Config + if treatment == defaultFallbackTreatment { + if t, c := i.getFallbackTreatment(feature); t != defaultFallbackTreatment { + treatment, config = t, c + } + } + imp := i.handleImpression(key, bk, feature, res, cfg.Metadata, SerializeProperties(evaluationOptions)) return &EvaluationResult{ - Treatment: res.Treatment, + Treatment: treatment, Impression: imp, - Config: res.Config, + Config: config, }, nil } -func getOptions(optFns ...OptFn) options { - options := defaultOpts() - for _, optFn := range optFns { - optFn(&options) - } - return options -} - // Treatment implements Interface -func (i *Impl) Treatments(cfg *types.ClientConfig, key string, bk *string, features []string, attributes Attributes, optFns ...OptFn) (map[string]EvaluationResult, error) { +func (i *Impl) Treatments(cfg *types.ClientConfig, key string, bk *string, features []string, attributes Attributes, evaluationOptions *dtos.EvaluationOptions) (map[string]EvaluationResult, error) { - options := getOptions(optFns...) res := i.ev.EvaluateFeatures(key, bk, features, attributes) toRet := make(map[string]EvaluationResult, len(res.Evaluations)) for _, feature := range features { curr, ok := res.Evaluations[feature] if !ok { - toRet[feature] = EvaluationResult{Treatment: "control"} + treatment, config := i.getFallbackTreatment(feature) + toRet[feature] = EvaluationResult{Treatment: treatment, Config: config} continue } - + treatment := curr.Treatment + config := curr.Config + if treatment == defaultFallbackTreatment { + if t, c := i.getFallbackTreatment(feature); t != defaultFallbackTreatment { + treatment, config = t, c + } + } var eres EvaluationResult - eres.Treatment = curr.Treatment - eres.Impression = i.handleImpression(key, bk, feature, &curr, cfg.Metadata, serializeProperties(options.evaluationOptions)) - eres.Config = curr.Config + eres.Treatment = treatment + eres.Impression = i.handleImpression(key, bk, feature, &curr, cfg.Metadata, SerializeProperties(evaluationOptions)) + eres.Config = config toRet[feature] = eres } @@ -197,16 +184,21 @@ func (i *Impl) Treatments(cfg *types.ClientConfig, key string, bk *string, featu } // TreatmentsByFlagSet implements Interface -func (i *Impl) TreatmentsByFlagSet(cfg *types.ClientConfig, key string, bk *string, flagSet string, attributes Attributes, optFns ...OptFn) (map[string]EvaluationResult, error) { +func (i *Impl) TreatmentsByFlagSet(cfg *types.ClientConfig, key string, bk *string, flagSet string, attributes Attributes, evaluationOptions *dtos.EvaluationOptions) (map[string]EvaluationResult, error) { - options := getOptions(optFns...) res := i.ev.EvaluateFeatureByFlagSets(key, bk, []string{flagSet}, attributes) toRet := make(map[string]EvaluationResult, len(res.Evaluations)) for feature, curr := range res.Evaluations { + treatment, config := curr.Treatment, curr.Config + if treatment == defaultFallbackTreatment { + if t, c := i.getFallbackTreatment(feature); t != defaultFallbackTreatment { + treatment, config = t, c + } + } var eres EvaluationResult - eres.Treatment = curr.Treatment - eres.Impression = i.handleImpression(key, bk, feature, &curr, cfg.Metadata, serializeProperties(options.evaluationOptions)) - eres.Config = curr.Config + eres.Treatment = treatment + eres.Impression = i.handleImpression(key, bk, feature, &curr, cfg.Metadata, SerializeProperties(evaluationOptions)) + eres.Config = config toRet[feature] = eres } @@ -214,16 +206,21 @@ func (i *Impl) TreatmentsByFlagSet(cfg *types.ClientConfig, key string, bk *stri } // TreatmentsByFlagSets implements Interface -func (i *Impl) TreatmentsByFlagSets(cfg *types.ClientConfig, key string, bk *string, flagSets []string, attributes Attributes, optFns ...OptFn) (map[string]EvaluationResult, error) { +func (i *Impl) TreatmentsByFlagSets(cfg *types.ClientConfig, key string, bk *string, flagSets []string, attributes Attributes, evaluationOptions *dtos.EvaluationOptions) (map[string]EvaluationResult, error) { - options := getOptions(optFns...) res := i.ev.EvaluateFeatureByFlagSets(key, bk, flagSets, attributes) toRet := make(map[string]EvaluationResult, len(res.Evaluations)) for feature, curr := range res.Evaluations { + treatment, config := curr.Treatment, curr.Config + if treatment == defaultFallbackTreatment { + if t, c := i.getFallbackTreatment(feature); t != defaultFallbackTreatment { + treatment, config = t, c + } + } var eres EvaluationResult - eres.Treatment = curr.Treatment - eres.Impression = i.handleImpression(key, bk, feature, &curr, cfg.Metadata, serializeProperties(options.evaluationOptions)) - eres.Config = curr.Config + eres.Treatment = treatment + eres.Impression = i.handleImpression(key, bk, feature, &curr, cfg.Metadata, SerializeProperties(evaluationOptions)) + eres.Config = config toRet[feature] = eres } @@ -354,6 +351,23 @@ func splitToView(s *dtos.SplitDTO) *SplitView { } } +const defaultFallbackTreatment = "control" + +func (i *Impl) getFallbackTreatment(feature string) (treatment string, config *string) { + treatment = defaultFallbackTreatment + ft := i.cfg.FallbackTreatment + if byFlag, ok := ft.ByFlagFallbackTreatment[feature]; ok && byFlag.Treatment != nil { + treatment = *byFlag.Treatment + config = byFlag.Config + return treatment, config + } + if ft.GlobalFallbackTreatment != nil && ft.GlobalFallbackTreatment.Treatment != nil { + treatment = *ft.GlobalFallbackTreatment.Treatment + config = ft.GlobalFallbackTreatment.Config + } + return treatment, config +} + func timeMillis() int64 { return time.Now().UTC().UnixMilli() } @@ -367,7 +381,7 @@ func createFallbackTreatmentCalculator(fallbackTreatmentConfig *dtos.FallbackTre return dtos.NewFallbackTreatmentCalculatorImp(&fallbackTreatmentConf) } -func serializeProperties(opts *dtos.EvaluationOptions) string { +func SerializeProperties(opts *dtos.EvaluationOptions) string { if opts == nil { return "" } diff --git a/splitio/sdk/sdk_test.go b/splitio/sdk/sdk_test.go index 1202487..230534f 100644 --- a/splitio/sdk/sdk_test.go +++ b/splitio/sdk/sdk_test.go @@ -61,7 +61,7 @@ func TestTreatmentsWithImpressionsDisabled(t *testing.T) { res, err := client.Treatments( &types.ClientConfig{Metadata: types.ClientMetadata{ID: "some", SdkVersion: "go-1.2.3"}}, - "key1", nil, []string{"f1", "f2", "f3"}, Attributes{"a": 1}) + "key1", nil, []string{"f1", "f2", "f3"}, Attributes{"a": 1}, nil) assert.Nil(t, err) assert.Nil(t, res["f1"].Config) assert.Nil(t, res["f2"].Config) @@ -101,7 +101,7 @@ func TestTreatmentLabelsDisabled(t *testing.T) { cfg: conf.Config{LabelsEnabled: false}, } - res, err := client.Treatment(&types.ClientConfig{Metadata: types.ClientMetadata{ID: "some", SdkVersion: "go-1.2.3"}}, "key1", nil, "f1", Attributes{"a": 1}) + res, err := client.Treatment(&types.ClientConfig{Metadata: types.ClientMetadata{ID: "some", SdkVersion: "go-1.2.3"}}, "key1", nil, "f1", Attributes{"a": 1}, nil) assert.Nil(t, err) assert.Nil(t, res.Config) assertImpEq(t, expectedImpression, res.Impression) @@ -156,7 +156,7 @@ func TestTreatmentLabelsEnabled(t *testing.T) { cfg: conf.Config{LabelsEnabled: true}, } - res, err := client.Treatment(&types.ClientConfig{Metadata: types.ClientMetadata{ID: "some", SdkVersion: "go-1.2.3"}}, "key1", nil, "f1", Attributes{"a": 1}) + res, err := client.Treatment(&types.ClientConfig{Metadata: types.ClientMetadata{ID: "some", SdkVersion: "go-1.2.3"}}, "key1", nil, "f1", Attributes{"a": 1}, nil) assert.Nil(t, err) assert.Nil(t, res.Config) assertImpEq(t, expectedImpression, res.Impression) @@ -226,7 +226,7 @@ func TestTreatments(t *testing.T) { res, err := client.Treatments( &types.ClientConfig{Metadata: types.ClientMetadata{ID: "some", SdkVersion: "go-1.2.3"}}, - "key1", nil, []string{"f1", "f2", "f3"}, Attributes{"a": 1}) + "key1", nil, []string{"f1", "f2", "f3"}, Attributes{"a": 1}, nil) assert.Nil(t, err) assert.Nil(t, res["f1"].Config) assert.Nil(t, res["f2"].Config) @@ -308,7 +308,7 @@ func TestTreatmentsWithImpressionProperties(t *testing.T) { } res, err := client.Treatments( &types.ClientConfig{Metadata: types.ClientMetadata{ID: "some", SdkVersion: "go-1.2.3"}}, - "key1", nil, []string{"f1", "f2", "f3"}, Attributes{"a": 1}, client.WithEvaluationOptions(&opts)) + "key1", nil, []string{"f1", "f2", "f3"}, Attributes{"a": 1}, &opts) assert.Nil(t, err) assert.Nil(t, res["f1"].Config) assert.Nil(t, res["f2"].Config) @@ -397,7 +397,7 @@ func TestTreatmentsByFlagSet(t *testing.T) { res, err := client.TreatmentsByFlagSet( &types.ClientConfig{Metadata: types.ClientMetadata{ID: "some", SdkVersion: "go-1.2.3"}}, - "key1", nil, "set", Attributes{"a": 1}) + "key1", nil, "set", Attributes{"a": 1}, nil) assert.Nil(t, err) assert.Nil(t, res["f1"].Config) assert.Nil(t, res["f2"].Config) @@ -486,7 +486,7 @@ func TestTreatmentsByFlagSets(t *testing.T) { res, err := client.TreatmentsByFlagSets( &types.ClientConfig{Metadata: types.ClientMetadata{ID: "some", SdkVersion: "go-1.2.3"}}, - "key1", nil, []string{"set_1", "set_2"}, Attributes{"a": 1}) + "key1", nil, []string{"set_1", "set_2"}, Attributes{"a": 1}, nil) assert.Nil(t, err) assert.Nil(t, res["f1"].Config) assert.Nil(t, res["f2"].Config) @@ -555,7 +555,7 @@ func TestImpressionsQueueFull(t *testing.T) { for idx := 0; idx < 4; idx++ { feature := fmt.Sprintf("f%d", idx) expectedImpression.FeatureName = feature - res, err := client.Treatment(clientConf, "key1", nil, feature, Attributes{"a": 1}) + res, err := client.Treatment(clientConf, "key1", nil, feature, Attributes{"a": 1}, nil) assert.Nil(t, err) assert.Nil(t, res.Config) assertImpEq(t, expectedImpression, res.Impression) @@ -567,7 +567,7 @@ func TestImpressionsQueueFull(t *testing.T) { for idx := 4; idx < 8; idx++ { feature := fmt.Sprintf("f%d", idx) expectedImpression.FeatureName = feature - res, err := client.Treatment(clientConf, "key1", nil, feature, Attributes{"a": 1}) + res, err := client.Treatment(clientConf, "key1", nil, feature, Attributes{"a": 1}, nil) assert.Nil(t, err) assert.Nil(t, res.Config) assertImpEq(t, expectedImpression, res.Impression) @@ -577,7 +577,7 @@ func TestImpressionsQueueFull(t *testing.T) { feature := "f8" expectedImpression.FeatureName = feature - res, err := client.Treatment(clientConf, "key1", nil, feature, Attributes{"a": 1}) + res, err := client.Treatment(clientConf, "key1", nil, feature, Attributes{"a": 1}, nil) assert.Nil(t, err) assert.Nil(t, res.Config) assertImpEq(t, expectedImpression, res.Impression)