How to timeout tests in Go

3 minute read Published: 2019-07-22

Today I want to share a nice trick I found while reading the tests for the bufio package in the standard library.

It is about how we can make a test case fail if it takes more than X seconds, where X is an arbitrary value we choose.

To give some context, I was working on some code that would never finish for certain inputs. After identifying and fixing the bug, I started to wonder how I could write a test case that would prevent this bug from ever returning to our codebase.

The code I was working with used the bufio.Reader struct, so it made sense to me to read the tests for that struct. That's how I found this trick.

The idea is really simple. Call the function you want to test in a separate goroutine and send its result through a channel. Then write a select statement that waits on both that channel and a timer. Whichever is ready first wins, so the timer's duration becomes your timeout.

To make it clearer, I will use a very contrived example of how this technique can be used. First, imagine that you want to test a function that receives a slice of ints and returns their sum as an int:

func sum(values []int) int {
	result := 0
	for _, v := range values {
		result += v
	}
	return result
}

Creating some test cases for this function is pretty simple, especially since we are using GoUnit to generate the boilerplate:

import (
	"reflect"
	"testing"
)

func Test_sum(t *testing.T) {
	type args struct {
		values []int
	}
	tests := []struct {
		name string
		args func(t *testing.T) args

		want1 int
	}{
		{
			name: "1, 2, 3, 4 should sum to 10",
			args: func(*testing.T) args {
				return args{values: []int{1, 2, 3, 4}}
			},
			want1: 10,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tArgs := tt.args(t)

			got1 := sum(tArgs.values)

			if !reflect.DeepEqual(got1, tt.want1) {
				t.Errorf("sum got1 = %v, want1: %v", got1, tt.want1)
			}
		})
	}
}

Now comes the fun stuff. With a small change to the loop, we can make sure this test fails if it takes more than 100 ms (the test file now also needs to import time):

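for _, tt := range tests {
	t.Run(tt.name, func(t *testing.T) {
		// Build the arguments on the test goroutine: t.Fatal and friends
		// must only be called from the goroutine running the test.
		tArgs := tt.args(t)

		// Buffered so that, if sum eventually returns after the timeout,
		// the goroutine can still send its result and exit.
		responses := make(chan int, 1)
		go func() {
			responses <- sum(tArgs.values)
		}()

		// Wait for the result or the timer, whichever comes first.
		select {
		case got1 := <-responses:
			if !reflect.DeepEqual(got1, tt.want1) {
				t.Errorf("sum got1 = %v, want1: %v", got1, tt.want1)
			}
		case <-time.After(100 * time.Millisecond):
			t.Error("test timed out")
		}
	})
}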

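Keep in mind that it would be a really bad idea to use this as a test against performance regressions. How long a test takes depends on the machine, on whatever else is running at the same time, and on flags such as -race, so a tight limit only makes the test flaky. Choose a timeout generous enough that only a real hang can exceed it.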