feat(cloudformation): add support for Fn::ForEach (#9508)

Signed-off-by: nikpivkin <nikita.pivkin@smartforce.io>
This commit is contained in:
Nikita Pivkin
2025-12-12 00:53:03 +06:00
committed by GitHub
parent 1a901e5c75
commit d65b504cb2
6 changed files with 490 additions and 84 deletions

View File

@@ -1,6 +1,11 @@
package parser
import (
"fmt"
"maps"
"regexp"
"strings"
"github.com/samber/lo"
"github.com/aquasecurity/trivy/pkg/iac/ignore"
@@ -12,6 +17,8 @@ type SourceFormat string
const (
YamlSourceFormat SourceFormat = "yaml"
JsonSourceFormat SourceFormat = "json"
ForEachPrefix = "Fn::ForEach::"
)
type FileContexts []*FileContext
@@ -81,3 +88,168 @@ func (t *FileContext) stripNullProperties() {
})
}
}
func (t *FileContext) expandTransforms() error {
resources := make(map[string]*Resource, len(t.Resources))
for name, r := range t.Resources {
if r.raw == nil {
resources[name] = r
continue
}
instances, err := t.expandTransform(r.raw, name)
if err != nil {
return err
}
for logicalID, rawProp := range instances {
instance, err := newExpandedResource(r, logicalID, rawProp)
if err != nil {
return err
}
resources[logicalID] = instance
}
}
t.Resources = resources
return nil
}
func newExpandedResource(base *Resource, logicalID string, raw *Property) (*Resource, error) {
rawMap := raw.AsMap()
typProp, ok := rawMap["Type"]
if !ok {
return nil, fmt.Errorf("missing 'Type' in expanded resource %q", logicalID)
}
propsProp, ok := rawMap["Properties"]
if !ok {
return nil, fmt.Errorf("missing 'Properties' in expanded resource %q", logicalID)
}
instance := base.clone()
instance.typ = typProp.AsString()
instance.properties = propsProp.AsMap()
instance.setId(logicalID)
return instance, nil
}
func (t *FileContext) expandTransform(prop *Property, logicalName string) (map[string]*Property, error) {
if strings.HasPrefix(logicalName, "Fn::ForEach::") {
return expandForEach(prop, nil)
}
return nil, nil
}
func expandForEach(prop *Property, parentCtx *LoopContext) (map[string]*Property, error) {
args := prop.AsList()
if len(args) != 3 {
return nil, fmt.Errorf("invalid Fn::ForEach: expected 3 arguments, got %d", len(args))
}
identifier := args[0].AsString()
coll := args[1].AsList()
templ := args[2].AsMap()
result := make(map[string]*Property)
for _, el := range coll {
loopCtx := parentCtx.Child(identifier, el)
for tmplKey, templValue := range templ {
cp := templValue.clone()
// handle nested loop
if strings.HasPrefix(tmplKey, ForEachPrefix) {
nestedResult, err := expandForEach(cp, loopCtx)
if err != nil {
return nil, err
}
maps.Copy(result, nestedResult)
continue
}
logicalID := resolveLoopPlaceholders(tmplKey, loopCtx)
cp.setLogicalResource(logicalID)
if err := expandProperties(cp, loopCtx); err != nil {
return nil, err
}
result[logicalID] = cp
}
}
return result, nil
}
var placeholderRe = regexp.MustCompile(`[$&]\{([^}]+)\}`)
func resolveLoopPlaceholders(v string, loopCtx *LoopContext) string {
return placeholderRe.ReplaceAllStringFunc(v, func(s string) string {
id := s[2 : len(s)-1]
val, found := loopCtx.Resolve(id)
if found {
return val.AsString()
}
return s
})
}
func expandProperties(prop *Property, parentCtx *LoopContext) error {
prop.loopCtx = parentCtx
switch v := prop.Value.(type) {
case string:
prop.Value = resolveLoopPlaceholders(v, parentCtx)
case map[string]*Property:
newProps := make(map[string]*Property)
for k, el := range v {
if strings.HasPrefix(k, ForEachPrefix) {
expanded, err := expandForEach(el, parentCtx)
if err != nil {
return err
}
maps.Copy(newProps, expanded)
} else {
if err := expandProperties(el, parentCtx); err != nil {
return err
}
newProps[k] = el
}
}
prop.Value = newProps
case []*Property:
for _, el := range v {
if err := expandProperties(el, parentCtx); err != nil {
return err
}
}
}
return nil
}
type LoopContext struct {
Identifier string
Value *Property
Parent *LoopContext
}
func (c *LoopContext) Child(identifier string, value *Property) *LoopContext {
return &LoopContext{
Identifier: identifier,
Value: value,
Parent: c,
}
}
func (c *LoopContext) Resolve(name string) (*Property, bool) {
if c.Identifier == name {
return c.Value, true
}
if c.Parent != nil {
return c.Parent.Resolve(name)
}
return nil, false
}

View File

@@ -4,7 +4,7 @@ import (
"github.com/aquasecurity/trivy/pkg/iac/scanners/cloudformation/cftypes"
)
func ResolveReference(property *Property) (resolved *Property, success bool) {
func ResolveReference(property *Property) (*Property, bool) {
if !property.isFunction() {
return property, true
}
@@ -19,16 +19,18 @@ func ResolveReference(property *Property) (resolved *Property, success bool) {
return property.deriveResolved(pseudo.t, pseudo.val), true
}
if property.loopCtx != nil {
v, found := property.loopCtx.Resolve(refValue)
if found {
return property.deriveResolved(v.Type, v.RawValue()), true
}
}
if property.ctx == nil {
return property, false
}
var param *Parameter
for k := range property.ctx.Parameters {
if k != refValue {
continue
}
param = property.ctx.Parameters[k]
if param, exists := property.ctx.Parameters[refValue]; exists {
resolvedType := param.Type()
switch param.Default().(type) {
@@ -40,16 +42,14 @@ func ResolveReference(property *Property) (resolved *Property, success bool) {
resolvedType = cftypes.Int
}
resolved = property.deriveResolved(resolvedType, param.Default())
resolved := property.deriveResolved(resolvedType, param.Default())
return resolved, true
}
for k := range property.ctx.Resources {
if k == refValue {
res := property.ctx.Resources[k]
resolved = property.deriveResolved(cftypes.String, res.ID())
break
}
if res, exists := property.ctx.Resources[refValue]; exists {
resolved := property.deriveResolved(cftypes.String, res.ID())
return resolved, true
}
return resolved, true
return nil, false
}

View File

@@ -163,6 +163,10 @@ func (p *Parser) ParseFile(ctx context.Context, fsys fs.FS, filePath string) (fc
r.configureResource(name, fsys, filePath, fctx)
}
if err := fctx.expandTransforms(); err != nil {
return nil, err
}
return fctx, nil
}

View File

@@ -482,3 +482,158 @@ Resources:
assert.True(t, res.GetProperty("PublicAccessBlockConfiguration.BlockPublicAcls").IsNil())
}
func Test_ExpandForEachYAML(t *testing.T) {
source := `AWSTemplateFormatVersion: 2010-09-09
Transform: AWS::LanguageExtensions
Parameters:
TopicNamesParam:
Type: CommaDelimitedList
Default: Success,Failure
Mappings:
Success:
Properties:
DisplayName: success
FifoTopic: "true"
Failure:
Properties:
DisplayName: failure
FifoTopic: "false"
Resources:
'Fn::ForEach::Topics':
- TopicName
- !Split [",", !Ref TopicNamesParam]
- 'SnsTopic${TopicName}':
Type: 'AWS::SNS::Topic'
Properties:
TopicName: !Sub '${TopicName}.fifo'
'Fn::ForEach::Properties':
- PropertyName
- [DisplayName, FifoTopic]
- '${PropertyName}':
'Fn::FindInMap':
- Ref: 'TopicName'
- Properties
- Ref: 'PropertyName'
'Fn::ForEach::Subscriptions':
- SubName
- ['Alpha', 'Beta']
- 'SnsSubscription${TopicName}${SubName}':
Type: 'AWS::SNS::Subscription'
Properties:
TopicArn: !Ref 'SnsTopic${TopicName}'
Protocol: email
Endpoint: !Sub '${SubName}@example.com'
`
files, err := parseFile(t, source, "cf.yaml")
require.NoError(t, err)
file := files[0]
assert.Len(t, file.Resources, 6)
tests := []struct {
LogicalID string
Props map[string]any
}{
// SnsTopic
{
"SnsTopicSuccess",
map[string]any{
"TopicName": "Success.fifo",
"DisplayName": "success",
"FifoTopic": "true",
},
},
{
"SnsTopicFailure",
map[string]any{
"TopicName": "Failure.fifo",
"DisplayName": "failure",
"FifoTopic": "false",
},
},
// SnsSubscription
{
"SnsSubscriptionSuccessAlpha",
map[string]any{
"TopicArn": "SnsTopicSuccess",
"Protocol": "email",
"Endpoint": "Alpha@example.com",
},
},
{
"SnsSubscriptionSuccessBeta",
map[string]any{
"TopicArn": "SnsTopicSuccess",
"Protocol": "email",
"Endpoint": "Beta@example.com",
},
},
{
"SnsSubscriptionFailureAlpha",
map[string]any{
"TopicArn": "SnsTopicFailure",
"Protocol": "email",
"Endpoint": "Alpha@example.com",
},
},
{
"SnsSubscriptionFailureBeta",
map[string]any{
"TopicArn": "SnsTopicFailure",
"Protocol": "email",
"Endpoint": "Beta@example.com",
},
},
}
for _, tt := range tests {
t.Run(tt.LogicalID, func(t *testing.T) {
res, ok := file.Resources[tt.LogicalID]
require.True(t, ok)
for propName, expected := range tt.Props {
prop := res.GetProperty(propName)
assert.Equal(t, expected, prop.RawValue())
}
})
}
}
func Test_ExpandForEachJSON(t *testing.T) {
source := `{
"AWSTemplateFormatVersion": "2010-09-09",
"Transform": "AWS::LanguageExtensions",
"Resources": {
"Fn::ForEach::Buckets": [
"Suffix",
["A", "B"],
{
"S3Bucket${Suffix}": {
"Type": "AWS::S3::Bucket",
"Properties": {
"BucketName": { "Fn::Sub": "bucket-${Suffix}" }
}
}
}
]
}
}`
files, err := parseFile(t, source, "cf.json")
require.NoError(t, err)
require.Len(t, files, 1)
file := files[0]
require.Len(t, file.Resources, 2)
b1, ok := file.Resources["S3BucketA"]
require.True(t, ok)
assert.Equal(t, "AWS::S3::Bucket", b1.Type())
assert.Equal(t, "bucket-A", b1.GetProperty("BucketName").AsString())
b2, ok := file.Resources["S3BucketB"]
require.True(t, ok)
assert.Equal(t, "AWS::S3::Bucket", b2.Type())
assert.Equal(t, "bucket-B", b2.GetProperty("BucketName").AsString())
}

View File

@@ -33,10 +33,8 @@ type Property struct {
parentRange iacTypes.Range
logicalId string
unresolved bool
}
func (p *Property) Comment() string {
return p.comment
loopCtx *LoopContext
}
func (p *Property) setName(name string) {
@@ -52,22 +50,17 @@ func (p *Property) setName(name string) {
}
func (p *Property) setContext(ctx *FileContext) {
p.ctx = ctx
p.walk(func(prop *Property) bool {
prop.ctx = ctx
return true
})
}
if p.IsMap() {
for _, subProp := range p.AsMap() {
if subProp == nil {
continue
}
subProp.setContext(ctx)
}
}
if p.IsList() {
for _, subProp := range p.AsList() {
subProp.setContext(ctx)
}
}
func (p *Property) setLogicalResource(id string) {
p.walk(func(prop *Property) bool {
prop.logicalId = id
return !prop.isFunction()
})
}
func (p *Property) setFileAndParentRange(target fs.FS, filepath string, parentRange iacTypes.Range) {
@@ -80,18 +73,55 @@ func (p *Property) setFileAndParentRange(target fs.FS, filepath string, parentRa
if subProp == nil {
continue
}
subProp.setFileAndParentRange(target, filepath, parentRange)
subProp.setFileAndParentRange(target, filepath, p.rng)
}
case cftypes.List:
for _, subProp := range p.AsList() {
if subProp == nil {
continue
}
subProp.setFileAndParentRange(target, filepath, parentRange)
subProp.setFileAndParentRange(target, filepath, p.rng)
}
}
}
func (p *Property) clone() *Property {
if p == nil {
return nil
}
clone := &Property{
Location: p.Location,
ctx: p.ctx,
Type: p.Type,
name: p.name,
comment: p.comment,
rng: p.rng,
parentRange: p.parentRange,
logicalId: p.logicalId,
unresolved: p.unresolved,
}
switch v := p.Value.(type) {
case map[string]*Property:
m := make(map[string]*Property, len(v))
for k, el := range v {
m[k] = el.clone()
}
clone.Value = m
case []*Property:
slice := make([]*Property, len(v))
for i, el := range v {
slice[i] = el.clone()
}
clone.Value = slice
default:
clone.Value = v
}
return clone
}
func (p *Property) UnmarshalYAML(node *yaml.Node) error {
p.StartLine = node.Line
p.EndLine = calculateEndLine(node)
@@ -301,17 +331,10 @@ func (p *Property) GetProperty(path string) *Property {
}
func (p *Property) deriveResolved(propType cftypes.CfType, propValue any) *Property {
return &Property{
Location: p.Location,
Value: propValue,
Type: propType,
ctx: p.ctx,
name: p.name,
comment: p.comment,
rng: p.rng,
parentRange: p.parentRange,
logicalId: p.logicalId,
}
clone := p.clone()
clone.Type = propType
clone.Value = propValue
return clone
}
func (p *Property) ParentRange() iacTypes.Range {
@@ -363,29 +386,6 @@ func (p *Property) String() string {
return r
}
func (p *Property) setLogicalResource(id string) {
p.logicalId = id
if p.isFunction() {
return
}
if p.IsMap() {
for _, subProp := range p.AsMap() {
if subProp == nil {
continue
}
subProp.setLogicalResource(id)
}
}
if p.IsList() {
for _, subProp := range p.AsList() {
subProp.setLogicalResource(id)
}
}
}
func (p *Property) GetJsonBytes(squashList ...bool) []byte {
if p.IsNil() {
return []byte{}
@@ -458,3 +458,28 @@ func (p *Property) inferType() {
}
p.Type = typ
}
func (p *Property) walk(fn func(*Property) bool) {
if fn == nil {
return
}
if !fn(p) {
return
}
switch v := p.Value.(type) {
case map[string]*Property:
for _, child := range v {
if child != nil {
child.walk(fn)
}
}
case []*Property:
for _, child := range v {
if child != nil {
child.walk(fn)
}
}
}
}

View File

@@ -3,6 +3,7 @@ package parser
import (
"encoding/json/jsontext"
"encoding/json/v2"
"fmt"
"io/fs"
"strings"
@@ -20,6 +21,8 @@ type Resource struct {
rng iacTypes.Range
id string
comment string
raw *Property
}
func (r *Resource) configureResource(id string, target fs.FS, filepath string, ctx *FileContext) {
@@ -46,13 +49,39 @@ func (r *Resource) setFile(target fs.FS, filepath string) {
func (r *Resource) setContext(ctx *FileContext) {
r.ctx = ctx
for _, p := range r.properties {
p.setLogicalResource(r.id)
p.setContext(ctx)
if r.raw != nil {
r.raw.setContext(ctx)
} else {
for _, p := range r.properties {
p.setLogicalResource(r.id)
p.setContext(ctx)
}
}
}
func (r *Resource) clone() *Resource {
clone := &Resource{
typ: r.typ,
ctx: r.ctx,
rng: r.rng,
id: r.id,
comment: r.comment,
}
if r.properties != nil {
clone.properties = make(map[string]*Property, len(r.properties))
for k, p := range r.properties {
clone.properties[k] = p.clone()
}
}
if r.raw != nil {
clone.raw = r.raw.clone()
}
return clone
}
type resourceInner struct {
Type string `json:"Type" yaml:"Type"`
Properties map[string]*Property `json:"Properties" yaml:"Properties"`
@@ -63,22 +92,43 @@ func (r *Resource) UnmarshalYAML(node *yaml.Node) error {
r.EndLine = calculateEndLine(node)
r.comment = node.LineComment
var i resourceInner
if err := node.Decode(&i); err != nil {
return err
switch node.Kind {
case yaml.MappingNode:
var i resourceInner
if err := node.Decode(&i); err != nil {
return err
}
r.typ = i.Type
r.properties = i.Properties
return nil
case yaml.SequenceNode:
var raw Property
if err := node.Decode(&raw); err != nil {
return err
}
r.raw = &raw
return nil
default:
return fmt.Errorf("unsupported YAML node kind: %v", node.Kind)
}
r.typ = i.Type
r.properties = i.Properties
return nil
}
func (r *Resource) UnmarshalJSONFrom(dec *jsontext.Decoder) error {
var i resourceInner
if err := json.UnmarshalDecode(dec, &i); err != nil {
return err
switch dec.PeekKind() {
case '{':
var i resourceInner
if err := json.UnmarshalDecode(dec, &i); err != nil {
return err
}
r.typ = i.Type
r.properties = i.Properties
case '[':
var raw Property
if err := json.UnmarshalDecode(dec, &raw); err != nil {
return err
}
r.raw = &raw
}
r.typ = i.Type
r.properties = i.Properties
return nil
}