Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions package_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2749,6 +2749,104 @@ func main() {
`)
}

func TestForRangeFunc0(t *testing.T) {
// Test range over func(yield func() bool) - 0 values
pkg := newMainPackage()
// Define iterator function type: func(yield func() bool)
yieldSig := types.NewSignatureType(nil, nil, nil, nil, types.NewTuple(types.NewVar(token.NoPos, nil, "", types.Typ[types.Bool])), false)
yieldParam := pkg.NewParam(token.NoPos, "yield", yieldSig)
iterSig := types.NewTuple(yieldParam)

// Define foo function
pkg.NewFunc(nil, "foo", iterSig, nil, false).BodyStart(pkg).
End()

pkg.NewFunc(nil, "main", nil, nil, false).BodyStart(pkg).
ForRange().Val(ctxRef(pkg, "foo")).RangeAssignThen(token.NoPos).
Val(pkg.Import("fmt").Ref("Println")).Val("Hi").Call(1).EndStmt().
End().
End()
domTest(t, pkg, `package main

import "fmt"

func foo(yield func() bool) {
}
func main() {
for range foo {
fmt.Println("Hi")
}
}
`)
}

func TestForRangeFunc1(t *testing.T) {
// Test range over func(yield func(V) bool) - 1 value
pkg := newMainPackage()
// Define iterator function type: func(yield func(string) bool)
yieldParamV := types.NewVar(token.NoPos, nil, "v", types.Typ[types.String])
yieldRet := types.NewVar(token.NoPos, nil, "", types.Typ[types.Bool])
yieldSig := types.NewSignatureType(nil, nil, nil, types.NewTuple(yieldParamV), types.NewTuple(yieldRet), false)
iterParam := pkg.NewParam(token.NoPos, "yield", yieldSig)
iterSig := types.NewTuple(iterParam)

// Define bar function
pkg.NewFunc(nil, "bar", iterSig, nil, false).BodyStart(pkg).
End()

pkg.NewFunc(nil, "main", nil, nil, false).BodyStart(pkg).
ForRange("v").Val(ctxRef(pkg, "bar")).RangeAssignThen(token.NoPos).
Val(pkg.Import("fmt").Ref("Println")).Val(ctxRef(pkg, "v")).Call(1).EndStmt().
End().
End()
domTest(t, pkg, `package main

import "fmt"

func bar(yield func(v string) bool) {
}
func main() {
for v := range bar {
fmt.Println(v)
}
}
`)
}

func TestForRangeFunc2(t *testing.T) {
// Test range over func(yield func(K, V) bool) - 2 values (key-value pairs)
pkg := newMainPackage()
// Define iterator function type: func(yield func(string, int) bool)
yieldParamK := types.NewVar(token.NoPos, nil, "k", types.Typ[types.String])
yieldParamV := types.NewVar(token.NoPos, nil, "v", types.Typ[types.Int])
yieldRet := types.NewVar(token.NoPos, nil, "", types.Typ[types.Bool])
yieldSig := types.NewSignatureType(nil, nil, nil, types.NewTuple(yieldParamK, yieldParamV), types.NewTuple(yieldRet), false)
iterParam := pkg.NewParam(token.NoPos, "yield", yieldSig)
iterSig := types.NewTuple(iterParam)

// Define weekdays function
pkg.NewFunc(nil, "weekdays", iterSig, nil, false).BodyStart(pkg).
End()

pkg.NewFunc(nil, "main", nil, nil, false).BodyStart(pkg).
ForRange("k", "v").Val(ctxRef(pkg, "weekdays")).RangeAssignThen(token.NoPos).
Val(pkg.Import("fmt").Ref("Println")).Val(ctxRef(pkg, "k")).Val(ctxRef(pkg, "v")).Call(2).EndStmt().
End().
End()
domTest(t, pkg, `package main

import "fmt"

func weekdays(yield func(k string, v int) bool) {
}
func main() {
for k, v := range weekdays {
fmt.Println(k, v)
}
}
`)
}

func TestReturn(t *testing.T) {
pkg := newMainPackage()
format := pkg.NewParam(token.NoPos, "format", types.Typ[types.String])
Expand Down
49 changes: 49 additions & 0 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,14 @@ retry:
if (t.Info() & types.IsString) != 0 {
return []types.Type{types.Typ[types.Int], types.Typ[types.Rune]}
}
case *types.Signature:
// Go 1.23 range over function types:
// func(yield func() bool) - 0 values
// func(yield func(V) bool) - 1 value
// func(yield func(K, V) bool) - 2 values
if kvt := checkIteratorFunc(t); kvt != nil {
return kvt
}
case *types.Named:
if kv, ok := p.checkUdt(cb, t); ok {
return kv
Expand All @@ -618,6 +626,47 @@ retry:
return nil
}

// checkIteratorFunc checks if sig is a Go 1.23 iterator function signature.
// Returns [keyType, valType] if valid, nil otherwise.
// For 0-value iterators, returns empty slice.
// For 1-value iterators, returns [valType, nil].
// For 2-value iterators, returns [keyType, valType].
func checkIteratorFunc(sig *types.Signature) []types.Type {
// Must have no results
if sig.Results().Len() != 0 {
return nil
}
// Must have exactly 1 parameter (the yield function)
if sig.Params().Len() != 1 {
return nil
}
// The parameter must be a function
yieldSig, ok := sig.Params().At(0).Type().(*types.Signature)
if !ok {
return nil
}
// yield must return bool
if yieldSig.Results().Len() != 1 {
return nil
}
retType := yieldSig.Results().At(0).Type()
basic, ok := retType.(*types.Basic)
if !ok || basic.Kind() != types.Bool {
return nil
}
// Check yield parameters (0, 1, or 2)
n := yieldSig.Params().Len()
switch n {
case 0:
return []types.Type{nil, nil}
case 1:
return []types.Type{yieldSig.Params().At(0).Type(), nil}
case 2:
return []types.Type{yieldSig.Params().At(0).Type(), yieldSig.Params().At(1).Type()}
}
return nil
}

func (p *forRangeStmt) checkUdt(cb *CodeBuilder, o *types.Named) ([]types.Type, bool) {
if enumName, sig := findEnumMethodType(cb, o); sig != nil {
p.enumName = enumName
Expand Down
Loading