diff --git a/internal/config/mongo_database_test.go b/internal/config/mongo_database_test.go new file mode 100644 index 0000000..c7b2fd9 --- /dev/null +++ b/internal/config/mongo_database_test.go @@ -0,0 +1,116 @@ +package config + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMongoDatabaseEnvironmentVariable(t *testing.T) { + defer func() { + // Clean up + os.Unsetenv("DECKARD_MONGODB_DATABASE") + os.Unsetenv("DECKARD_MONGO_DATABASE") + }() + + // Clean environment + os.Unsetenv("DECKARD_MONGODB_DATABASE") + os.Unsetenv("DECKARD_MONGO_DATABASE") + + // Test default value + Configure(true) + require.Equal(t, "deckard", MongoDatabase.Get()) + + // Test with DECKARD_MONGODB_DATABASE environment variable + os.Setenv("DECKARD_MONGODB_DATABASE", "custom_db_1") + Configure(true) + require.Equal(t, "custom_db_1", MongoDatabase.Get()) + + // Test with DECKARD_MONGO_DATABASE environment variable + os.Unsetenv("DECKARD_MONGODB_DATABASE") + os.Setenv("DECKARD_MONGO_DATABASE", "custom_db_2") + Configure(true) + require.Equal(t, "custom_db_2", MongoDatabase.Get()) +} + +func TestMongoCollectionEnvironmentVariable(t *testing.T) { + defer func() { + // Clean up + os.Unsetenv("DECKARD_MONGODB_COLLECTION") + os.Unsetenv("DECKARD_MONGO_COLLECTION") + }() + + // Clean environment + os.Unsetenv("DECKARD_MONGODB_COLLECTION") + os.Unsetenv("DECKARD_MONGO_COLLECTION") + + // Test default value + Configure(true) + require.Equal(t, "queue", MongoCollection.Get()) + + // Test with DECKARD_MONGODB_COLLECTION environment variable + os.Setenv("DECKARD_MONGODB_COLLECTION", "custom_collection") + Configure(true) + require.Equal(t, "custom_collection", MongoCollection.Get()) + + // Test with DECKARD_MONGO_COLLECTION environment variable + os.Unsetenv("DECKARD_MONGODB_COLLECTION") + os.Setenv("DECKARD_MONGO_COLLECTION", "custom_collection_2") + Configure(true) + require.Equal(t, "custom_collection_2", MongoCollection.Get()) +} + +func TestMongoBooleanEnvironmentVariable(t *testing.T) { + defer func() { + // Clean up + os.Unsetenv("DECKARD_MONGODB_SSL") + os.Unsetenv("DECKARD_MONGO_SSL") + }() + + // Clean environment + os.Unsetenv("DECKARD_MONGODB_SSL") + os.Unsetenv("DECKARD_MONGO_SSL") + + // Test default value + Configure(true) + require.Equal(t, false, MongoSsl.GetBool()) + + // Test with DECKARD_MONGODB_SSL environment variable + os.Setenv("DECKARD_MONGODB_SSL", "true") + Configure(true) + require.Equal(t, true, MongoSsl.GetBool()) + + // Test with DECKARD_MONGO_SSL environment variable + os.Unsetenv("DECKARD_MONGODB_SSL") + os.Setenv("DECKARD_MONGO_SSL", "true") + Configure(true) + require.Equal(t, true, MongoSsl.GetBool()) +} + +func TestMongoIntegerEnvironmentVariable(t *testing.T) { + defer func() { + // Clean up + os.Unsetenv("DECKARD_MONGODB_MAX_POOL_SIZE") + os.Unsetenv("DECKARD_MONGO_MAX_POOL_SIZE") + }() + + // Clean environment + os.Unsetenv("DECKARD_MONGODB_MAX_POOL_SIZE") + os.Unsetenv("DECKARD_MONGO_MAX_POOL_SIZE") + + // Test default value + Configure(true) + require.Equal(t, 100, MongoMaxPoolSize.GetInt()) + + // Test with DECKARD_MONGODB_MAX_POOL_SIZE environment variable + os.Setenv("DECKARD_MONGODB_MAX_POOL_SIZE", "250") + Configure(true) + require.Equal(t, 250, MongoMaxPoolSize.GetInt()) + + // Test with DECKARD_MONGO_MAX_POOL_SIZE environment variable + os.Unsetenv("DECKARD_MONGODB_MAX_POOL_SIZE") + os.Setenv("DECKARD_MONGO_MAX_POOL_SIZE", "300") + Configure(true) + require.Equal(t, 300, MongoMaxPoolSize.GetInt()) +} diff --git a/internal/config/viper_config.go b/internal/config/viper_config.go index 1d54118..899dc4c 100644 --- a/internal/config/viper_config.go +++ b/internal/config/viper_config.go @@ -32,80 +32,88 @@ func (config *ViperConfigKey) Set(value any) { } } -// Should never be called before config is initialized using config.Configure() -func (config *ViperConfigKey) Get() string { +// getWithFallback is a helper that implements the common logic for all getter methods. +// It checks the main key and aliases for values different from the default, returning +// the first override found, or the default if no overrides exist. +func getWithFallback[T comparable](config *ViperConfigKey, defaultVal T, keyGetter func(string) T) T { + // Check main key - if it differs from default, use it (environment variable takes precedence) if viper.IsSet(config.Key) { - return viper.GetString(config.Key) + keyVal := keyGetter(config.Key) + if keyVal != defaultVal { + return keyVal + } } + // Check aliases - if any differs from default, use it (environment variable takes precedence) for _, alias := range config.GetAliases() { if viper.IsSet(alias) { - return viper.GetString(alias) + aliasVal := keyGetter(alias) + if aliasVal != defaultVal { + return aliasVal + } } } + // Return default value + return defaultVal +} + +// Should never be called before config is initialized using config.Configure() +func (config *ViperConfigKey) Get() string { + defaultVal := "" if val, ok := config.GetDefault().(string); ok { - return val + defaultVal = val } - return "" + return getWithFallback(config, defaultVal, viper.GetString) } // Should never be called before config is initialized using config.Configure() func (config *ViperConfigKey) GetDuration() time.Duration { - if viper.IsSet(config.Key) { - return viper.GetDuration(config.Key) - } - - for _, alias := range config.GetAliases() { - if viper.IsSet(alias) { - return viper.GetDuration(alias) + defaultVal := time.Duration(0) + if val, ok := config.GetDefault().(string); ok { + parsed, err := time.ParseDuration(val) + if err == nil { + defaultVal = parsed } } - if val, ok := config.GetDefault().(string); ok { - duration, _ := time.ParseDuration(val) - - return duration + // Use a custom getter that ensures consistent parsing behavior + getDuration := func(key string) time.Duration { + if viper.IsSet(key) { + // Try viper's built-in parsing first + if duration := viper.GetDuration(key); duration != 0 { + return duration + } + // Fall back to manual parsing if viper returns 0 but key is set + if str := viper.GetString(key); str != "" { + if parsed, err := time.ParseDuration(str); err == nil { + return parsed + } + } + } + return 0 } - return 0 + return getWithFallback(config, defaultVal, getDuration) } // Should never be called before config is initialized using config.Configure() func (config *ViperConfigKey) GetBool() bool { - if viper.IsSet(config.Key) { - return viper.GetBool(config.Key) - } - - for _, alias := range config.GetAliases() { - if viper.IsSet(alias) { - return viper.GetBool(alias) - } - } - + defaultVal := false if val, ok := config.GetDefault().(bool); ok { - return val + defaultVal = val } - return false + return getWithFallback(config, defaultVal, viper.GetBool) } // Should never be called before config is initialized using config.Configure() func (config *ViperConfigKey) GetInt() int { - if viper.IsSet(config.Key) { - return viper.GetInt(config.Key) - } - - for _, alias := range config.GetAliases() { - if viper.IsSet(alias) { - return viper.GetInt(alias) - } - } - + defaultVal := 0 if val, ok := config.GetDefault().(int); ok { - return val + defaultVal = val } - return 0 + return getWithFallback(config, defaultVal, viper.GetInt) } diff --git a/internal/service/cert/.gitignore b/internal/service/cert/.gitignore index 612424a..35697cb 100644 --- a/internal/service/cert/.gitignore +++ b/internal/service/cert/.gitignore @@ -1 +1,2 @@ -*.pem \ No newline at end of file +*.pem +*.srl \ No newline at end of file