diff --git a/pkg/iac/scanners/cloudformation/parser/file_context.go b/pkg/iac/scanners/cloudformation/parser/file_context.go index b1b2732eff..7b0bc83183 100644 --- a/pkg/iac/scanners/cloudformation/parser/file_context.go +++ b/pkg/iac/scanners/cloudformation/parser/file_context.go @@ -1,6 +1,11 @@ package parser import ( + "fmt" + "maps" + "regexp" + "strings" + "github.com/samber/lo" "github.com/aquasecurity/trivy/pkg/iac/ignore" @@ -12,6 +17,8 @@ type SourceFormat string const ( YamlSourceFormat SourceFormat = "yaml" JsonSourceFormat SourceFormat = "json" + + ForEachPrefix = "Fn::ForEach::" ) type FileContexts []*FileContext @@ -81,3 +88,168 @@ func (t *FileContext) stripNullProperties() { }) } } + +func (t *FileContext) expandTransforms() error { + resources := make(map[string]*Resource, len(t.Resources)) + + for name, r := range t.Resources { + if r.raw == nil { + resources[name] = r + continue + } + + instances, err := t.expandTransform(r.raw, name) + if err != nil { + return err + } + + for logicalID, rawProp := range instances { + instance, err := newExpandedResource(r, logicalID, rawProp) + if err != nil { + return err + } + resources[logicalID] = instance + } + } + + t.Resources = resources + return nil +} + +func newExpandedResource(base *Resource, logicalID string, raw *Property) (*Resource, error) { + rawMap := raw.AsMap() + typProp, ok := rawMap["Type"] + if !ok { + return nil, fmt.Errorf("missing 'Type' in expanded resource %q", logicalID) + } + propsProp, ok := rawMap["Properties"] + if !ok { + return nil, fmt.Errorf("missing 'Properties' in expanded resource %q", logicalID) + } + + instance := base.clone() + instance.typ = typProp.AsString() + instance.properties = propsProp.AsMap() + instance.setId(logicalID) + return instance, nil +} + +func (t *FileContext) expandTransform(prop *Property, logicalName string) (map[string]*Property, error) { + if strings.HasPrefix(logicalName, "Fn::ForEach::") { + return expandForEach(prop, nil) + } + + return nil, nil +} + +func expandForEach(prop *Property, parentCtx *LoopContext) (map[string]*Property, error) { + + args := prop.AsList() + if len(args) != 3 { + return nil, fmt.Errorf("invalid Fn::ForEach: expected 3 arguments, got %d", len(args)) + } + + identifier := args[0].AsString() + coll := args[1].AsList() + templ := args[2].AsMap() + + result := make(map[string]*Property) + + for _, el := range coll { + loopCtx := parentCtx.Child(identifier, el) + + for tmplKey, templValue := range templ { + cp := templValue.clone() + + // handle nested loop + if strings.HasPrefix(tmplKey, ForEachPrefix) { + nestedResult, err := expandForEach(cp, loopCtx) + if err != nil { + return nil, err + } + maps.Copy(result, nestedResult) + continue + } + + logicalID := resolveLoopPlaceholders(tmplKey, loopCtx) + cp.setLogicalResource(logicalID) + if err := expandProperties(cp, loopCtx); err != nil { + return nil, err + } + + result[logicalID] = cp + } + } + + return result, nil +} + +var placeholderRe = regexp.MustCompile(`[$&]\{([^}]+)\}`) + +func resolveLoopPlaceholders(v string, loopCtx *LoopContext) string { + return placeholderRe.ReplaceAllStringFunc(v, func(s string) string { + id := s[2 : len(s)-1] + val, found := loopCtx.Resolve(id) + if found { + return val.AsString() + } + return s + }) +} + +func expandProperties(prop *Property, parentCtx *LoopContext) error { + prop.loopCtx = parentCtx + + switch v := prop.Value.(type) { + case string: + prop.Value = resolveLoopPlaceholders(v, parentCtx) + case map[string]*Property: + newProps := make(map[string]*Property) + for k, el := range v { + if strings.HasPrefix(k, ForEachPrefix) { + expanded, err := expandForEach(el, parentCtx) + if err != nil { + return err + } + maps.Copy(newProps, expanded) + } else { + if err := expandProperties(el, parentCtx); err != nil { + return err + } + newProps[k] = el + } + } + prop.Value = newProps + case []*Property: + for _, el := range v { + if err := expandProperties(el, parentCtx); err != nil { + return err + } + } + } + return nil +} + +type LoopContext struct { + Identifier string + Value *Property + Parent *LoopContext +} + +func (c *LoopContext) Child(identifier string, value *Property) *LoopContext { + return &LoopContext{ + Identifier: identifier, + Value: value, + Parent: c, + } +} + +func (c *LoopContext) Resolve(name string) (*Property, bool) { + if c.Identifier == name { + return c.Value, true + } + if c.Parent != nil { + return c.Parent.Resolve(name) + } + return nil, false +} diff --git a/pkg/iac/scanners/cloudformation/parser/fn_ref.go b/pkg/iac/scanners/cloudformation/parser/fn_ref.go index afc6ead7cf..8357d4a35b 100644 --- a/pkg/iac/scanners/cloudformation/parser/fn_ref.go +++ b/pkg/iac/scanners/cloudformation/parser/fn_ref.go @@ -4,7 +4,7 @@ import ( "github.com/aquasecurity/trivy/pkg/iac/scanners/cloudformation/cftypes" ) -func ResolveReference(property *Property) (resolved *Property, success bool) { +func ResolveReference(property *Property) (*Property, bool) { if !property.isFunction() { return property, true } @@ -19,16 +19,18 @@ func ResolveReference(property *Property) (resolved *Property, success bool) { return property.deriveResolved(pseudo.t, pseudo.val), true } + if property.loopCtx != nil { + v, found := property.loopCtx.Resolve(refValue) + if found { + return property.deriveResolved(v.Type, v.RawValue()), true + } + } + if property.ctx == nil { return property, false } - var param *Parameter - for k := range property.ctx.Parameters { - if k != refValue { - continue - } - param = property.ctx.Parameters[k] + if param, exists := property.ctx.Parameters[refValue]; exists { resolvedType := param.Type() switch param.Default().(type) { @@ -40,16 +42,14 @@ func ResolveReference(property *Property) (resolved *Property, success bool) { resolvedType = cftypes.Int } - resolved = property.deriveResolved(resolvedType, param.Default()) + resolved := property.deriveResolved(resolvedType, param.Default()) return resolved, true } - for k := range property.ctx.Resources { - if k == refValue { - res := property.ctx.Resources[k] - resolved = property.deriveResolved(cftypes.String, res.ID()) - break - } + if res, exists := property.ctx.Resources[refValue]; exists { + resolved := property.deriveResolved(cftypes.String, res.ID()) + return resolved, true } - return resolved, true + + return nil, false } diff --git a/pkg/iac/scanners/cloudformation/parser/parser.go b/pkg/iac/scanners/cloudformation/parser/parser.go index 8ac14b3270..af60b759c3 100644 --- a/pkg/iac/scanners/cloudformation/parser/parser.go +++ b/pkg/iac/scanners/cloudformation/parser/parser.go @@ -163,6 +163,10 @@ func (p *Parser) ParseFile(ctx context.Context, fsys fs.FS, filePath string) (fc r.configureResource(name, fsys, filePath, fctx) } + if err := fctx.expandTransforms(); err != nil { + return nil, err + } + return fctx, nil } diff --git a/pkg/iac/scanners/cloudformation/parser/parser_test.go b/pkg/iac/scanners/cloudformation/parser/parser_test.go index 9c7e5d1400..f3b410370e 100644 --- a/pkg/iac/scanners/cloudformation/parser/parser_test.go +++ b/pkg/iac/scanners/cloudformation/parser/parser_test.go @@ -482,3 +482,158 @@ Resources: assert.True(t, res.GetProperty("PublicAccessBlockConfiguration.BlockPublicAcls").IsNil()) } + +func Test_ExpandForEachYAML(t *testing.T) { + source := `AWSTemplateFormatVersion: 2010-09-09 +Transform: AWS::LanguageExtensions +Parameters: + TopicNamesParam: + Type: CommaDelimitedList + Default: Success,Failure +Mappings: + Success: + Properties: + DisplayName: success + FifoTopic: "true" + Failure: + Properties: + DisplayName: failure + FifoTopic: "false" +Resources: + 'Fn::ForEach::Topics': + - TopicName + - !Split [",", !Ref TopicNamesParam] + - 'SnsTopic${TopicName}': + Type: 'AWS::SNS::Topic' + Properties: + TopicName: !Sub '${TopicName}.fifo' + 'Fn::ForEach::Properties': + - PropertyName + - [DisplayName, FifoTopic] + - '${PropertyName}': + 'Fn::FindInMap': + - Ref: 'TopicName' + - Properties + - Ref: 'PropertyName' + 'Fn::ForEach::Subscriptions': + - SubName + - ['Alpha', 'Beta'] + - 'SnsSubscription${TopicName}${SubName}': + Type: 'AWS::SNS::Subscription' + Properties: + TopicArn: !Ref 'SnsTopic${TopicName}' + Protocol: email + Endpoint: !Sub '${SubName}@example.com' +` + + files, err := parseFile(t, source, "cf.yaml") + require.NoError(t, err) + file := files[0] + + assert.Len(t, file.Resources, 6) + + tests := []struct { + LogicalID string + Props map[string]any + }{ + // SnsTopic + { + "SnsTopicSuccess", + map[string]any{ + "TopicName": "Success.fifo", + "DisplayName": "success", + "FifoTopic": "true", + }, + }, + { + "SnsTopicFailure", + map[string]any{ + "TopicName": "Failure.fifo", + "DisplayName": "failure", + "FifoTopic": "false", + }, + }, + // SnsSubscription + { + "SnsSubscriptionSuccessAlpha", + map[string]any{ + "TopicArn": "SnsTopicSuccess", + "Protocol": "email", + "Endpoint": "Alpha@example.com", + }, + }, + { + "SnsSubscriptionSuccessBeta", + map[string]any{ + "TopicArn": "SnsTopicSuccess", + "Protocol": "email", + "Endpoint": "Beta@example.com", + }, + }, + { + "SnsSubscriptionFailureAlpha", + map[string]any{ + "TopicArn": "SnsTopicFailure", + "Protocol": "email", + "Endpoint": "Alpha@example.com", + }, + }, + { + "SnsSubscriptionFailureBeta", + map[string]any{ + "TopicArn": "SnsTopicFailure", + "Protocol": "email", + "Endpoint": "Beta@example.com", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.LogicalID, func(t *testing.T) { + res, ok := file.Resources[tt.LogicalID] + require.True(t, ok) + for propName, expected := range tt.Props { + prop := res.GetProperty(propName) + assert.Equal(t, expected, prop.RawValue()) + } + }) + } +} + +func Test_ExpandForEachJSON(t *testing.T) { + source := `{ + "AWSTemplateFormatVersion": "2010-09-09", + "Transform": "AWS::LanguageExtensions", + "Resources": { + "Fn::ForEach::Buckets": [ + "Suffix", + ["A", "B"], + { + "S3Bucket${Suffix}": { + "Type": "AWS::S3::Bucket", + "Properties": { + "BucketName": { "Fn::Sub": "bucket-${Suffix}" } + } + } + } + ] + } + }` + + files, err := parseFile(t, source, "cf.json") + require.NoError(t, err) + require.Len(t, files, 1) + + file := files[0] + require.Len(t, file.Resources, 2) + + b1, ok := file.Resources["S3BucketA"] + require.True(t, ok) + assert.Equal(t, "AWS::S3::Bucket", b1.Type()) + assert.Equal(t, "bucket-A", b1.GetProperty("BucketName").AsString()) + + b2, ok := file.Resources["S3BucketB"] + require.True(t, ok) + assert.Equal(t, "AWS::S3::Bucket", b2.Type()) + assert.Equal(t, "bucket-B", b2.GetProperty("BucketName").AsString()) +} diff --git a/pkg/iac/scanners/cloudformation/parser/property.go b/pkg/iac/scanners/cloudformation/parser/property.go index 9d8a662cdb..42feede738 100644 --- a/pkg/iac/scanners/cloudformation/parser/property.go +++ b/pkg/iac/scanners/cloudformation/parser/property.go @@ -33,10 +33,8 @@ type Property struct { parentRange iacTypes.Range logicalId string unresolved bool -} -func (p *Property) Comment() string { - return p.comment + loopCtx *LoopContext } func (p *Property) setName(name string) { @@ -52,22 +50,17 @@ func (p *Property) setName(name string) { } func (p *Property) setContext(ctx *FileContext) { - p.ctx = ctx + p.walk(func(prop *Property) bool { + prop.ctx = ctx + return true + }) +} - if p.IsMap() { - for _, subProp := range p.AsMap() { - if subProp == nil { - continue - } - subProp.setContext(ctx) - } - } - - if p.IsList() { - for _, subProp := range p.AsList() { - subProp.setContext(ctx) - } - } +func (p *Property) setLogicalResource(id string) { + p.walk(func(prop *Property) bool { + prop.logicalId = id + return !prop.isFunction() + }) } func (p *Property) setFileAndParentRange(target fs.FS, filepath string, parentRange iacTypes.Range) { @@ -80,18 +73,55 @@ func (p *Property) setFileAndParentRange(target fs.FS, filepath string, parentRa if subProp == nil { continue } - subProp.setFileAndParentRange(target, filepath, parentRange) + subProp.setFileAndParentRange(target, filepath, p.rng) } case cftypes.List: for _, subProp := range p.AsList() { if subProp == nil { continue } - subProp.setFileAndParentRange(target, filepath, parentRange) + subProp.setFileAndParentRange(target, filepath, p.rng) } } } +func (p *Property) clone() *Property { + if p == nil { + return nil + } + + clone := &Property{ + Location: p.Location, + ctx: p.ctx, + Type: p.Type, + name: p.name, + comment: p.comment, + rng: p.rng, + parentRange: p.parentRange, + logicalId: p.logicalId, + unresolved: p.unresolved, + } + + switch v := p.Value.(type) { + case map[string]*Property: + m := make(map[string]*Property, len(v)) + for k, el := range v { + m[k] = el.clone() + } + clone.Value = m + case []*Property: + slice := make([]*Property, len(v)) + for i, el := range v { + slice[i] = el.clone() + } + clone.Value = slice + default: + clone.Value = v + } + + return clone +} + func (p *Property) UnmarshalYAML(node *yaml.Node) error { p.StartLine = node.Line p.EndLine = calculateEndLine(node) @@ -301,17 +331,10 @@ func (p *Property) GetProperty(path string) *Property { } func (p *Property) deriveResolved(propType cftypes.CfType, propValue any) *Property { - return &Property{ - Location: p.Location, - Value: propValue, - Type: propType, - ctx: p.ctx, - name: p.name, - comment: p.comment, - rng: p.rng, - parentRange: p.parentRange, - logicalId: p.logicalId, - } + clone := p.clone() + clone.Type = propType + clone.Value = propValue + return clone } func (p *Property) ParentRange() iacTypes.Range { @@ -363,29 +386,6 @@ func (p *Property) String() string { return r } -func (p *Property) setLogicalResource(id string) { - p.logicalId = id - - if p.isFunction() { - return - } - - if p.IsMap() { - for _, subProp := range p.AsMap() { - if subProp == nil { - continue - } - subProp.setLogicalResource(id) - } - } - - if p.IsList() { - for _, subProp := range p.AsList() { - subProp.setLogicalResource(id) - } - } -} - func (p *Property) GetJsonBytes(squashList ...bool) []byte { if p.IsNil() { return []byte{} @@ -458,3 +458,28 @@ func (p *Property) inferType() { } p.Type = typ } + +func (p *Property) walk(fn func(*Property) bool) { + if fn == nil { + return + } + + if !fn(p) { + return + } + + switch v := p.Value.(type) { + case map[string]*Property: + for _, child := range v { + if child != nil { + child.walk(fn) + } + } + case []*Property: + for _, child := range v { + if child != nil { + child.walk(fn) + } + } + } +} diff --git a/pkg/iac/scanners/cloudformation/parser/resource.go b/pkg/iac/scanners/cloudformation/parser/resource.go index 35b8e6d92e..6b93b90cd6 100644 --- a/pkg/iac/scanners/cloudformation/parser/resource.go +++ b/pkg/iac/scanners/cloudformation/parser/resource.go @@ -3,6 +3,7 @@ package parser import ( "encoding/json/jsontext" "encoding/json/v2" + "fmt" "io/fs" "strings" @@ -20,6 +21,8 @@ type Resource struct { rng iacTypes.Range id string comment string + + raw *Property } func (r *Resource) configureResource(id string, target fs.FS, filepath string, ctx *FileContext) { @@ -46,13 +49,39 @@ func (r *Resource) setFile(target fs.FS, filepath string) { func (r *Resource) setContext(ctx *FileContext) { r.ctx = ctx - - for _, p := range r.properties { - p.setLogicalResource(r.id) - p.setContext(ctx) + if r.raw != nil { + r.raw.setContext(ctx) + } else { + for _, p := range r.properties { + p.setLogicalResource(r.id) + p.setContext(ctx) + } } } +func (r *Resource) clone() *Resource { + clone := &Resource{ + typ: r.typ, + ctx: r.ctx, + rng: r.rng, + id: r.id, + comment: r.comment, + } + + if r.properties != nil { + clone.properties = make(map[string]*Property, len(r.properties)) + for k, p := range r.properties { + clone.properties[k] = p.clone() + } + } + + if r.raw != nil { + clone.raw = r.raw.clone() + } + + return clone +} + type resourceInner struct { Type string `json:"Type" yaml:"Type"` Properties map[string]*Property `json:"Properties" yaml:"Properties"` @@ -63,22 +92,43 @@ func (r *Resource) UnmarshalYAML(node *yaml.Node) error { r.EndLine = calculateEndLine(node) r.comment = node.LineComment - var i resourceInner - if err := node.Decode(&i); err != nil { - return err + switch node.Kind { + case yaml.MappingNode: + var i resourceInner + if err := node.Decode(&i); err != nil { + return err + } + r.typ = i.Type + r.properties = i.Properties + return nil + case yaml.SequenceNode: + var raw Property + if err := node.Decode(&raw); err != nil { + return err + } + r.raw = &raw + return nil + default: + return fmt.Errorf("unsupported YAML node kind: %v", node.Kind) } - r.typ = i.Type - r.properties = i.Properties - return nil } func (r *Resource) UnmarshalJSONFrom(dec *jsontext.Decoder) error { - var i resourceInner - if err := json.UnmarshalDecode(dec, &i); err != nil { - return err + switch dec.PeekKind() { + case '{': + var i resourceInner + if err := json.UnmarshalDecode(dec, &i); err != nil { + return err + } + r.typ = i.Type + r.properties = i.Properties + case '[': + var raw Property + if err := json.UnmarshalDecode(dec, &raw); err != nil { + return err + } + r.raw = &raw } - r.typ = i.Type - r.properties = i.Properties return nil }