diff --git a/datamodel/low/base/schema.go b/datamodel/low/base/schema.go index 9b2fccf..56213f6 100644 --- a/datamodel/low/base/schema.go +++ b/datamodel/low/base/schema.go @@ -1254,10 +1254,12 @@ func ExtractSchema(ctx context.Context, root *yaml.Node, idx *index.SpecIndex) ( if rf, rl, _ := utils.IsNodeRefValue(root); rf { // locate reference in index. isRef = true - ref, _, _ := low.LocateRefNode(root, idx) + ref, fIdx, _, nCtx := low.LocateRefNodeWithContext(ctx, root, idx) if ref != nil { schNode = ref schLabel = rl + ctx = nCtx + idx = fIdx } else { return nil, fmt.Errorf(errStr, root.Content[1].Value, root.Content[1].Line, root.Content[1].Column) diff --git a/datamodel/low/extraction_functions.go b/datamodel/low/extraction_functions.go index 4268213..6998f47 100644 --- a/datamodel/low/extraction_functions.go +++ b/datamodel/low/extraction_functions.go @@ -352,12 +352,14 @@ func ExtractArray[T Buildable[N], N any](ctx context.Context, label string, root var ln, vn *yaml.Node var circError error root = utils.NodeAlias(root) + isRef := false if rf, rl, _ := utils.IsNodeRefValue(root); rf { - ref, fIdx, err, nCtx := LocateRefNodeWithContext(ctx, root, idx) + ref, fIdx, err, nCtx := LocateRefEnd(ctx, root, idx, 0) if ref != nil { + isRef = true vn = ref ln = rl - fIdx = fIdx + idx = fIdx ctx = nCtx if err != nil { circError = err @@ -370,8 +372,9 @@ func ExtractArray[T Buildable[N], N any](ctx context.Context, label string, root _, ln, vn = utils.FindKeyNodeFullTop(label, root.Content) if vn != nil { if h, _, _ := utils.IsNodeRefValue(vn); h { - ref, fIdx, err, nCtx := LocateRefNodeWithContext(ctx, vn, idx) + ref, fIdx, err, nCtx := LocateRefEnd(ctx, vn, idx, 0) if ref != nil { + isRef = true vn = ref idx = fIdx ctx = nCtx @@ -381,8 +384,9 @@ func ExtractArray[T Buildable[N], N any](ctx context.Context, label string, root } } else { if err != nil { - return []ValueReference[T]{}, nil, nil, fmt.Errorf("array build failed: reference cannot be found: %s", - err.Error()) + return []ValueReference[T]{}, nil, nil, + fmt.Errorf("array build failed: reference cannot be found: %s", + err.Error()) } } } @@ -392,7 +396,22 @@ func ExtractArray[T Buildable[N], N any](ctx context.Context, label string, root var items []ValueReference[T] if vn != nil && ln != nil { if !utils.IsNodeArray(vn) { - return []ValueReference[T]{}, nil, nil, fmt.Errorf("array build failed, input is not an array, line %d, column %d", vn.Line, vn.Column) + + if !isRef { + return []ValueReference[T]{}, nil, nil, + fmt.Errorf("array build failed, input is not an array, line %d, column %d", vn.Line, vn.Column) + } + // if this was pulled from a ref, but it's not a sequence, check the label and see if anything comes out, + // and then check that is a sequence, if not, fail it. + _, _, fvn := utils.FindKeyNodeFullTop(label, vn.Content) + if fvn != nil { + if !utils.IsNodeArray(vn) { + return []ValueReference[T]{}, nil, nil, + fmt.Errorf("array build failed, input is not an array, line %d, column %d", vn.Line, vn.Column) + } else { + vn = fvn + } + } } for _, node := range vn.Content { localReferenceValue := "" @@ -402,7 +421,7 @@ func ExtractArray[T Buildable[N], N any](ctx context.Context, label string, root foundIndex := idx if rf, _, rv := utils.IsNodeRefValue(node); rf { - refg, fIdx, err, nCtx := LocateRefNodeWithContext(ctx, node, idx) + refg, fIdx, err, nCtx := LocateRefEnd(ctx, node, idx, 0) if refg != nil { node = refg //localIsReference = true @@ -601,11 +620,13 @@ func ExtractMapExtensions[PT Buildable[N], N any]( root = utils.NodeAlias(root) if rf, rl, rv := utils.IsNodeRefValue(root); rf { // locate reference in index. - ref, _, err := LocateRefNode(root, idx) + ref, fIdx, err, fCtx := LocateRefNodeWithContext(ctx, root, idx) if ref != nil { valueNode = ref labelNode = rl referenceValue = rv + ctx = fCtx + idx = fIdx if err != nil { circError = err } @@ -834,3 +855,27 @@ func GenerateHashString(v any) string { } return fmt.Sprintf(HASH, sha256.Sum256([]byte(fmt.Sprint(v)))) } + +func LocateRefEnd(ctx context.Context, root *yaml.Node, idx *index.SpecIndex, depth int) (*yaml.Node, *index.SpecIndex, error, context.Context) { + depth++ + if depth > 100 { + return nil, nil, fmt.Errorf("reference resolution depth exceeded, possible circular reference"), ctx + } + ref, fIdx, err, nCtx := LocateRefNodeWithContext(ctx, root, idx) + if err != nil { + return ref, fIdx, err, nCtx + } + if ref != nil { + if rf, _, _ := utils.IsNodeRefValue(ref); rf { + return LocateRefEnd(nCtx, ref, fIdx, depth) + } else { + return ref, fIdx, err, nCtx + } + } else { + if root.Content[1].Value == "" { + return nil, nil, fmt.Errorf("reference at line %d, column %d is empty, it cannot be resolved", + root.Content[1].Line, root.Content[1].Column), ctx + } + return nil, nil, fmt.Errorf("reference cannot be found: %s", root.Content[1].Value), ctx + } +} diff --git a/datamodel/low/extraction_functions_test.go b/datamodel/low/extraction_functions_test.go index 3403f60..76316e2 100644 --- a/datamodel/low/extraction_functions_test.go +++ b/datamodel/low/extraction_functions_test.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "fmt" "golang.org/x/sync/syncmap" + "gopkg.in/yaml.v3" "net/url" "os" "strings" @@ -15,7 +16,6 @@ import ( "github.com/pb33f/libopenapi/index" "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v3" ) func TestFindItemInMap(t *testing.T) { @@ -748,7 +748,7 @@ func TestExtractArray_Ref_Circular(t *testing.T) { things, _, _, err := ExtractArray[*test_Good](context.Background(), "", cNode.Content[0], idx) assert.Error(t, err) - assert.Len(t, things, 0) + assert.Len(t, things, 2) } func TestExtractArray_Ref_Bad(t *testing.T) { @@ -890,7 +890,7 @@ func TestExtractArray_Ref_Nested_CircularFlat(t *testing.T) { assert.NoError(t, e) things, _, _, err := ExtractArray[*test_Good](context.Background(), "limes", cNode.Content[0], idx) assert.Error(t, err) - assert.Len(t, things, 0) + assert.Len(t, things, 2) } func TestExtractArray_BadBuild(t *testing.T) { @@ -1950,7 +1950,7 @@ func TestLocateRefNode_DoARealLookup(t *testing.T) { // fake cache to a lookup for a file that does not exist will work. fakeCache := new(syncmap.Map) - fakeCache.Store("/root.yaml#/components/schemas/Burger", &index.Reference{Node: &no}) + fakeCache.Store("/root.yaml#/components/schemas/Burger", &index.Reference{Node: &no, Index: idx}) idx.SetCache(fakeCache) ctx := context.WithValue(context.Background(), index.CurrentPathKey, "/root.yaml#/components/schemas/Burger") diff --git a/index/resolver.go b/index/resolver.go index 4b074d6..fe8c8e0 100644 --- a/index/resolver.go +++ b/index/resolver.go @@ -265,6 +265,16 @@ func visitIndex(res *Resolver, idx *SpecIndex) { } } + schemas = idx.GetAllSecuritySchemes() + for s, schemaRef := range schemas { + if mappedIndex[s] == nil { + seenReferences := make(map[string]bool) + var journey []*Reference + res.journeysTaken++ + schemaRef.Node.Content = res.VisitReference(schemaRef, seenReferences, journey, true) + } + } + // map everything for _, sequenced := range idx.GetAllSequencedReferences() { locatedDef := mappedIndex[sequenced.Definition] @@ -279,7 +289,12 @@ func visitIndex(res *Resolver, idx *SpecIndex) { // VisitReference will visit a reference as part of a journey and will return resolved nodes. func (resolver *Resolver) VisitReference(ref *Reference, seen map[string]bool, journey []*Reference, resolve bool) []*yaml.Node { resolver.referencesVisited++ - if ref.Resolved || ref.Seen { + if resolve && ref.Seen { + if ref.Resolved { + return ref.Node.Content + } + } + if !resolve && ref.Seen { return ref.Node.Content } @@ -342,13 +357,15 @@ func (resolver *Resolver) VisitReference(ref *Reference, seen map[string]bool, j } resolved := resolver.VisitReference(original, seen, journey, resolve) if resolve && !original.Circular { + ref.Resolved = true + r.Resolved = true r.Node.Content = resolved // this is where we perform the actual resolving. } r.Seen = true ref.Seen = true } } - ref.Resolved = true + ref.Seen = true return ref.Node.Content @@ -521,6 +538,10 @@ func (resolver *Resolver) extractRelatives(ref *Reference, node, parent *yaml.No continue } + if resolve { + ref.Node = locatedRef.Node + } + schemaType := "" if parent != nil { _, arrayTypevn := utils.FindKeyNodeTop("type", parent.Content) diff --git a/index/resolver_test.go b/index/resolver_test.go index 597aea4..c9d16c5 100644 --- a/index/resolver_test.go +++ b/index/resolver_test.go @@ -532,7 +532,7 @@ components: assert.NotNil(t, resolver) err := resolver.Resolve() - assert.Len(t, err, 1) + assert.Len(t, err, 2) assert.Equal(t, "cannot resolve reference `go home, I am drunk`, it's missing: $go home, I am drunk [18:11]", err[0].Error()) } diff --git a/index/search_index.go b/index/search_index.go index cb3277e..16d107c 100644 --- a/index/search_index.go +++ b/index/search_index.go @@ -34,7 +34,8 @@ func (index *SpecIndex) SearchIndexForReferenceWithContext(ctx context.Context, func (index *SpecIndex) SearchIndexForReferenceByReferenceWithContext(ctx context.Context, searchRef *Reference) (*Reference, *SpecIndex, context.Context) { if v, ok := index.cache.Load(searchRef.FullDefinition); ok { - return v.(*Reference), index, context.WithValue(ctx, CurrentPathKey, v.(*Reference).RemoteLocation) + //return v.(*Reference), index, context.WithValue(ctx, CurrentPathKey, v.(*Reference).RemoteLocation) + return v.(*Reference), v.(*Reference).Index, context.WithValue(ctx, CurrentPathKey, v.(*Reference).RemoteLocation) } ref := searchRef.FullDefinition @@ -163,6 +164,16 @@ func (index *SpecIndex) SearchIndexForReferenceByReferenceWithContext(ctx contex } } } + } else { + if r, ok := index.allMappedRefs[ref]; ok { + index.cache.Store(ref, r) + return r, r.Index, context.WithValue(ctx, CurrentPathKey, r.RemoteLocation) + } + + if r, ok := index.allMappedRefs[refAlt]; ok { + index.cache.Store(refAlt, r) + return r, r.Index, context.WithValue(ctx, CurrentPathKey, r.RemoteLocation) + } } if index.logger != nil { diff --git a/index/utility_methods.go b/index/utility_methods.go index 0930502..48a710c 100644 --- a/index/utility_methods.go +++ b/index/utility_methods.go @@ -308,19 +308,24 @@ func (index *SpecIndex) extractComponentExamples(examplesNode *yaml.Node, pathPr } func (index *SpecIndex) extractComponentSecuritySchemes(securitySchemesNode *yaml.Node, pathPrefix string) { + var name string - for i, secScheme := range securitySchemesNode.Content { + for i, schema := range securitySchemesNode.Content { if i%2 == 0 { - name = secScheme.Value + name = schema.Value continue } def := fmt.Sprintf("%s%s", pathPrefix, name) + fullDef := fmt.Sprintf("%s%s", index.specAbsolutePath, def) + ref := &Reference{ - Definition: def, - Name: name, - Node: secScheme, - ParentNode: securitySchemesNode, - Path: fmt.Sprintf("$.components.securitySchemes.%s", name), + FullDefinition: fullDef, + Definition: def, + Name: name, + Node: schema, + Path: fmt.Sprintf("$.components.securitySchemes.%s", name), + ParentNode: securitySchemesNode, + RequiredRefProperties: extractDefinitionRequiredRefProperties(securitySchemesNode, map[string][]string{}, fullDef, index), } index.allSecuritySchemes[def] = ref } @@ -340,6 +345,16 @@ func (index *SpecIndex) countUniqueInlineDuplicates() int { return unique } +func seekRefEnd(index *SpecIndex, refName string) *Reference { + ref, _ := index.SearchIndexForReference(refName) + if ref != nil { + if ok, _, v := utils.IsNodeRefValue(ref.Node); ok { + return seekRefEnd(ref.Index, v) + } + } + return ref +} + func (index *SpecIndex) scanOperationParams(params []*yaml.Node, pathItemNode *yaml.Node, method string) { for i, param := range params { // param is ref @@ -349,7 +364,7 @@ func (index *SpecIndex) scanOperationParams(params []*yaml.Node, pathItemNode *y paramRef := index.allMappedRefs[paramRefName] if paramRef == nil { // could be in the rolodex - ref, _ := index.SearchIndexForReference(paramRefName) + ref := seekRefEnd(index, paramRefName) if ref != nil { paramRef = ref }