diff --git a/index/circular_reference_result.go b/index/circular_reference_result.go index 95270b7..a710d6a 100644 --- a/index/circular_reference_result.go +++ b/index/circular_reference_result.go @@ -9,15 +9,21 @@ type CircularReferenceResult struct { LoopIndex int LoopPoint *Reference IsPolymorphicResult bool // if this result comes from a polymorphic loop. + IsInfiniteLoop bool // if all the definitions in the reference loop are marked as required, this is an infinite circular reference, thus is not allowed. } func (c *CircularReferenceResult) GenerateJourneyPath() string { buf := strings.Builder{} for i, ref := range c.Journey { - buf.WriteString(ref.Name) - if i+1 < len(c.Journey) { + if i > 0 { buf.WriteString(" -> ") } + + buf.WriteString(ref.Name) + // buf.WriteString(" (") + // buf.WriteString(ref.Definition) + // buf.WriteString(")") } + return buf.String() } diff --git a/index/spec_index.go b/index/spec_index.go index 341f19c..8e5779d 100644 --- a/index/spec_index.go +++ b/index/spec_index.go @@ -35,16 +35,17 @@ const ( // Reference is a wrapper around *yaml.Node results to make things more manageable when performing // algorithms on data models. the *yaml.Node def is just a bit too low level for tracking state. type Reference struct { - Definition string - Name string - Node *yaml.Node - ParentNode *yaml.Node - Resolved bool - Circular bool - Seen bool - IsRemote bool - RemoteLocation string - Path string // this won't always be available. + Definition string + Name string + Node *yaml.Node + ParentNode *yaml.Node + Resolved bool + Circular bool + Seen bool + IsRemote bool + RemoteLocation string + Path string // this won't always be available. + RequiredRefProperties map[string][]string // definition names (eg, #/definitions/One) to a list of required properties on this definition which reference that definition } // ReferenceMapped is a helper struct for mapped references put into sequence (we lose the key) @@ -1614,10 +1615,12 @@ func (index *SpecIndex) FindComponent(componentId string, parent *yaml.Node) *Re return nil } + // FIXME: This is a potential security hole, and needs to be made optional (see log4j fiasco) remoteLookup := func(id string) (*yaml.Node, *yaml.Node, error) { return index.lookupRemoteReference(id) } + // FIXME: As above fileLookup := func(id string) (*yaml.Node, *yaml.Node, error) { return index.lookupFileReference(id) } @@ -1698,18 +1701,93 @@ func (index *SpecIndex) extractDefinitionsAndSchemas(schemasNode *yaml.Node, pat name = schema.Value continue } + def := fmt.Sprintf("%s%s", pathPrefix, name) ref := &Reference{ - Definition: def, - Name: name, - Node: schema, - Path: fmt.Sprintf("$.components.schemas.%s", name), - ParentNode: schemasNode, + Definition: def, + Name: name, + Node: schema, + Path: fmt.Sprintf("$.components.schemas.%s", name), + ParentNode: schemasNode, + RequiredRefProperties: index.extractDefinitionRequiredRefProperties(schemasNode, map[string][]string{}), } index.allSchemas[def] = ref } } +// extractDefinitionRequiredRefProperties goes through the direct properties of a schema and extracts the map of required definitions from within it +func (index *SpecIndex) extractDefinitionRequiredRefProperties(schemaNode *yaml.Node, reqRefProps map[string][]string) map[string][]string { + if schemaNode == nil { + return reqRefProps + } + + _, requiredSeqNode := utils.FindKeyNode("required", schemaNode.Content) + if requiredSeqNode == nil { + return reqRefProps + } + + _, propertiesMapNode := utils.FindKeyNode("properties", schemaNode.Content) + if propertiesMapNode == nil { + // TODO: Log a warning on the resolver, because if you have required properties, but no actual properties, something is wrong + return reqRefProps + } + + name := "" + for i, param := range propertiesMapNode.Content { + if i%2 == 0 { + name = param.Value + continue + } + + _, paramPropertiesMapNode := utils.FindKeyNode("properties", param.Content) + if paramPropertiesMapNode != nil { + reqRefProps = index.extractDefinitionRequiredRefProperties(param, reqRefProps) + } + + for _, key := range []string{"allOf", "oneOf", "anyOf"} { + _, ofNode := utils.FindKeyNode(key, param.Content) + if ofNode != nil { + for _, ofNodeItem := range ofNode.Content { + reqRefProps = index.extractRequiredReferenceProperties(ofNodeItem, name, reqRefProps) + } + } + } + } + + for _, requiredPropertyNode := range requiredSeqNode.Content { + _, requiredPropDefNode := utils.FindKeyNode(requiredPropertyNode.Value, propertiesMapNode.Content) + if requiredPropDefNode == nil { + continue + } + + reqRefProps = index.extractRequiredReferenceProperties(requiredPropDefNode, requiredPropertyNode.Value, reqRefProps) + } + + return reqRefProps +} + +// extractRequiredReferenceProperties returns a map of definition names to the property or properties which reference it within a node +func (index *SpecIndex) extractRequiredReferenceProperties(requiredPropDefNode *yaml.Node, propName string, reqRefProps map[string][]string) map[string][]string { + isRef, _, defPath := utils.IsNodeRefValue(requiredPropDefNode) + if !isRef { + _, defItems := utils.FindKeyNode("items", requiredPropDefNode.Content) + if defItems != nil { + isRef, _, defPath = utils.IsNodeRefValue(defItems) + } + } + + if /* still */ !isRef { + return reqRefProps + } + + if _, ok := reqRefProps[defPath]; !ok { + reqRefProps[defPath] = []string{} + } + reqRefProps[defPath] = append(reqRefProps[defPath], propName) + + return reqRefProps +} + func (index *SpecIndex) extractComponentParameters(paramsNode *yaml.Node, pathPrefix string) { var name string for i, param := range paramsNode.Content { @@ -1899,11 +1977,13 @@ func (index *SpecIndex) FindComponentInRoot(componentId string) *Reference { res, _ := path.Find(index.root) if len(res) == 1 { ref := &Reference{ - Definition: componentId, - Name: name, - Node: res[0], - Path: friendlySearch, + Definition: componentId, + Name: name, + Node: res[0], + Path: friendlySearch, + RequiredRefProperties: index.extractDefinitionRequiredRefProperties(res[0], map[string][]string{}), } + return ref } } diff --git a/resolver/resolver.go b/resolver/resolver.go index 6f37c68..ddd45b3 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -5,6 +5,7 @@ package resolver import ( "fmt" + "github.com/pb33f/libopenapi/index" "github.com/pb33f/libopenapi/utils" "gopkg.in/yaml.v3" @@ -87,7 +88,6 @@ func (resolver *Resolver) GetNonPolymorphicCircularErrors() []*index.CircularRef // re-organize the node tree. Make sure you have copied your original tree before running this (if you want to preserve // original data) func (resolver *Resolver) Resolve() []*ResolvingError { - mapped := resolver.specIndex.GetMappedReferencesSequenced() mappedIndex := resolver.specIndex.GetMappedReferences() @@ -98,7 +98,6 @@ func (resolver *Resolver) Resolve() []*ResolvingError { } schemas := resolver.specIndex.GetAllSchemas() - for s, schemaRef := range schemas { if mappedIndex[s] == nil { seenReferences := make(map[string]bool) @@ -118,8 +117,13 @@ func (resolver *Resolver) Resolve() []*ResolvingError { } for _, circRef := range resolver.circularReferences { + // If the circular reference is not required, we can ignore it, as it's a terminable loop rather than an infinite one + if !circRef.IsInfiniteLoop { + continue + } + resolver.resolvingErrors = append(resolver.resolvingErrors, &ResolvingError{ - ErrorRef: fmt.Errorf("Circular reference detected: %s", circRef.Start.Name), + ErrorRef: fmt.Errorf("Infinite circular reference detected: %s", circRef.Start.Name), Node: circRef.LoopPoint.Node, Path: circRef.GenerateJourneyPath(), }) @@ -130,7 +134,6 @@ func (resolver *Resolver) Resolve() []*ResolvingError { // CheckForCircularReferences Check for circular references, without resolving, a non-destructive run. func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError { - mapped := resolver.specIndex.GetMappedReferencesSequenced() mappedIndex := resolver.specIndex.GetMappedReferences() for _, ref := range mapped { @@ -138,6 +141,7 @@ func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError { var journey []*index.Reference resolver.VisitReference(ref.Reference, seenReferences, journey, false) } + schemas := resolver.specIndex.GetAllSchemas() for s, schemaRef := range schemas { if mappedIndex[s] == nil { @@ -146,9 +150,15 @@ func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError { resolver.VisitReference(schemaRef, seenReferences, journey, false) } } + for _, circRef := range resolver.circularReferences { + // If the circular reference is not required, we can ignore it, as it's a terminable loop rather than an infinite one + if !circRef.IsInfiniteLoop { + continue + } + resolver.resolvingErrors = append(resolver.resolvingErrors, &ResolvingError{ - ErrorRef: fmt.Errorf("Circular reference detected: %s", circRef.Start.Name), + ErrorRef: fmt.Errorf("Infinite circular reference detected: %s", circRef.Start.Name), Node: circRef.LoopPoint.Node, Path: circRef.GenerateJourneyPath(), CircularReference: circRef, @@ -161,7 +171,6 @@ func (resolver *Resolver) CheckForCircularReferences() []*ResolvingError { // VisitReference will visit a reference as part of a journey and will return resolved nodes. func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]bool, journey []*index.Reference, resolve bool) []*yaml.Node { - if ref.Resolved || ref.Seen { return ref.Node.Content } @@ -173,34 +182,32 @@ func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]b seen[ref.Definition] = true for _, r := range relatives { - // check if we have seen this on the journey before, if so! it's circular skip := false for i, j := range journey { if j.Definition == r.Definition { - foundDup := resolver.specIndex.GetMappedReferences()[r.Definition] var circRef *index.CircularReferenceResult if !foundDup.Circular { - loop := append(journey, foundDup) + circRef = &index.CircularReferenceResult{ - Journey: loop, - Start: foundDup, - LoopIndex: i, - LoopPoint: foundDup, + Journey: loop, + Start: foundDup, + LoopIndex: i, + LoopPoint: foundDup, + IsInfiniteLoop: resolver.isInfiniteCircularDependency(foundDup, nil), } + resolver.circularReferences = append(resolver.circularReferences, circRef) foundDup.Seen = true foundDup.Circular = true - resolver.circularReferences = append(resolver.circularReferences, circRef) - } skip = true - } } + if !skip { original := resolver.specIndex.GetMappedReferences()[r.Definition] resolved := resolver.VisitReference(original, seen, journey, resolve) @@ -217,6 +224,30 @@ func (resolver *Resolver) VisitReference(ref *index.Reference, seen map[string]b return ref.Node.Content } +func (resolver *Resolver) isInfiniteCircularDependency(ref *index.Reference, initialRef *index.Reference) bool { + if ref == nil { + return false + } + + for refDefinition := range ref.RequiredRefProperties { + r := resolver.specIndex.GetMappedReferences()[refDefinition] + if initialRef != nil && initialRef.Definition == r.Definition { + return true + } + + ir := initialRef + if ir == nil { + ir = ref + } + + if resolver.isInfiniteCircularDependency(r, ir) { + return true + } + } + + return false +} + func (resolver *Resolver) extractRelatives(node *yaml.Node, foundRelatives map[string]bool, journey []*index.Reference, resolve bool) []*index.Reference { diff --git a/utils/utils.go b/utils/utils.go index 395579d..da028b2 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -3,12 +3,13 @@ package utils import ( "encoding/json" "fmt" - "github.com/vmware-labs/yaml-jsonpath/pkg/yamlpath" - "gopkg.in/yaml.v3" "net/url" "regexp" "strconv" "strings" + + "github.com/vmware-labs/yaml-jsonpath/pkg/yamlpath" + "gopkg.in/yaml.v3" ) type Case int8 @@ -402,15 +403,50 @@ func IsNodeRefValue(node *yaml.Node) (bool, *yaml.Node, string) { return false, nil, "" } +// IsPropertyNodeRequired will check if a node is required within circular references +func IsPropertyNodeRequired(node *yaml.Node, propertyName string) bool { + _, requiredSeqNode := FindKeyNode("required", node.Content) + if requiredSeqNode == nil { + return false + } + + _, propertiesMapNode := FindKeyNode("properties", node.Content) + if propertiesMapNode == nil { + return false + } + + for _, requiredPropertyNode := range requiredSeqNode.Content { + _, requiredPropDefNode := FindKeyNode(requiredPropertyNode.Value, propertiesMapNode.Content) + if requiredPropDefNode == nil { + continue + } + + isRef, _, defPath := IsNodeRefValue(requiredPropDefNode) + if isRef && defPath == propertyName { + return true + } + + _, defItems := FindKeyNode("items", requiredPropDefNode.Content) + if defItems == nil { + continue + } + + isRef, _, defPath = IsNodeRefValue(defItems) + if isRef && defPath == propertyName { + return true + } + } + + return false +} + // FixContext will clean up a JSONpath string to be correctly traversable. func FixContext(context string) string { - tokens := strings.Split(context, ".") var cleaned = []string{} + for i, t := range tokens { - if v, err := strconv.Atoi(t); err == nil { - if v < 200 { // codes start here if cleaned[i-1] != "" { cleaned[i-1] += fmt.Sprintf("[%v]", t) @@ -421,8 +457,8 @@ func FixContext(context string) string { continue } cleaned = append(cleaned, strings.ReplaceAll(t, "(root)", "$")) - } + return strings.Join(cleaned, ".") }