adrianhesketh.com

Mocking AWS SDK calls in Go

Today, one of my colleagues asked me how to unit test Go code that uses the AWS SDK. There’s nothing particularly unique about the AWS SDK: it’s similar to any third-party library whose results can vary depending on external circumstances. The network connection to the AWS endpoints might be down, a firewall rule might block the traffic, or the call might fail because AWS credentials are missing.

When testing, I want to make sure that if an AWS SDK call fails, my code recovers or exits gracefully and logs the problem details for analysis.

I often start by building a simple command-line tool for integration testing, to make sure I’ve written the AWS SDK code correctly before I integrate it into my program logic. For example, I’d write a very simple CLI tool to push a message to a Kinesis stream, or to upload a file to an S3 bucket. That way, I know that the AWS-specific code works correctly:

package main

import (
	"fmt"
	"os"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
)

func main() {
	region := "eu-west-2"
	bucket := "my-bucket"
	key := "test.txt"
	body, err := os.Open("test.txt")
	if err != nil {
		fmt.Printf("error opening file: %v\n", err)
		os.Exit(1)
	}
	// Close the file when main returns.
	defer body.Close()

	sess, err := session.NewSession(&aws.Config{
		Region: aws.String(region)},
	)
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	svc := s3.New(sess)
	_, err = svc.PutObject(&s3.PutObjectInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
		Body:   body,
	})
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
}

My next step is to separate out the variables. Some of them are configuration that the AWS service needs in order to operate, while others are the inputs that change each time the service is used. For example, when uploading data to an S3 bucket, the AWS region and the bucket name are configuration, while the key name and the data to upload vary from call to call.

So, really, I need a function with the signature PutFile(key string, data io.ReadSeeker) error.

I can now write my application logic, allowing any function that can put a file somewhere to be passed in as a value.

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"io"
)

func writeRandomFile(putFile func(key string, data io.ReadSeeker) error, length int, name string) error {
	data := make([]byte, length)
	_, err := rand.Read(data)
	if err != nil {
		return fmt.Errorf("writeRandomFile: failed to create random data: %v", err)
	}
	err = putFile(name, bytes.NewReader(data))
	if err != nil {
		return fmt.Errorf("writeRandomFile: failed to write file %q: %v", name, err)
	}
	return nil
}

The putFile signature is a bit verbose, so if you don’t mind the extra abstraction, you can give the function signature a type name. Here, I’ve given func(key string, data io.ReadSeeker) error the name filePutter:

type filePutter func(key string, data io.ReadSeeker) error

func writeRandomFile(putFile filePutter, length int, name string) error {
	data := make([]byte, length)
	_, err := rand.Read(data)
	if err != nil {
		return fmt.Errorf("writeRandomFile: failed to create random data: %v", err)
	}
	err = putFile(name, bytes.NewReader(data))
	if err != nil {
		return fmt.Errorf("writeRandomFile: failed to write file %q: %v", name, err)
	}
	return nil
}

I can then start writing unit tests by passing a fake implementation of this function to my application logic. I’ll start by checking that if writing the file fails, writeRandomFile returns an error.

func TestWritingFailsReturnsAnError(t *testing.T) {
	mockFailure := func(key string, data io.ReadSeeker) error {
		return errors.New("failed to read for some reason")
	}
	err := writeRandomFile(mockFailure, 10, "filename")
	if err == nil {
		// Stop here: calling err.Error() below on a nil error would panic.
		t.Fatalf("expected an error when writing the file fails")
	}
	expectedErr := "writeRandomFile: failed to write file \"filename\": failed to read for some reason"
	if err.Error() != expectedErr {
		t.Errorf("expected error: %v, got: %v", expectedErr, err)
	}
}
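The same fake-function approach covers the success path too. Here is a minimal sketch, assuming a hypothetical recordingPutter helper that never fails and records the key and the number of bytes it was given. It is written as a standalone program so that it runs on its own; in a real project the checks in main would sit inside a TestXxx function in a _test.go file.

```go
package main

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"io"
)

type filePutter func(key string, data io.ReadSeeker) error

func writeRandomFile(putFile filePutter, length int, name string) error {
	data := make([]byte, length)
	if _, err := rand.Read(data); err != nil {
		return fmt.Errorf("writeRandomFile: failed to create random data: %v", err)
	}
	if err := putFile(name, bytes.NewReader(data)); err != nil {
		return fmt.Errorf("writeRandomFile: failed to write file %q: %v", name, err)
	}
	return nil
}

// recordingPutter is a hypothetical fake: it always succeeds, and it stores the
// key and the number of bytes it received in the variables the caller passes in.
func recordingPutter(gotKey *string, gotLen *int) filePutter {
	return func(key string, data io.ReadSeeker) error {
		b, err := io.ReadAll(data)
		if err != nil {
			return err
		}
		*gotKey = key
		*gotLen = len(b)
		return nil
	}
}

func main() {
	var key string
	var n int
	err := writeRandomFile(recordingPutter(&key, &n), 10, "filename")
	fmt.Println(err, key, n) // <nil> filename 10
}
```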

Now, I can plug the AWS S3 implementation of filePutter into the real program:

func main() {
	if err := writeRandomFile(putS3, 10, "filename"); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	fmt.Println("OK")
}

func putS3(key string, data io.ReadSeeker) error {
	region := "eu-west-2"
	bucket := "my-bucket"
	sess, err := session.NewSession(&aws.Config{
		Region: aws.String(region)},
	)
	if err != nil {
		return err
	}
	svc := s3.New(sess)
	_, err = svc.PutObject(&s3.PutObjectInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
		Body:   data,
	})
	return err
}

type filePutter func(key string, data io.ReadSeeker) error

func writeRandomFile(putFile filePutter, length int, name string) error {
	data := make([]byte, length)
	_, err := rand.Read(data)
	if err != nil {
		return fmt.Errorf("writeRandomFile: failed to create random data: %v", err)
	}
	err = putFile(name, bytes.NewReader(data))
	if err != nil {
		return fmt.Errorf("writeRandomFile: failed to write file %q: %v", name, err)
	}
	return nil
}

So, now I have something I can test (writeRandomFile), and also something I can use (putS3 passed to writeRandomFile). One problem is that the region and bucket are hard-coded to “eu-west-2” and “my-bucket”, and they’re exactly the sort of things that change between test and production environments. Adding them to the function signature (e.g. putS3(region, bucket, key string, data io.ReadSeeker)) wouldn’t be a good design: putS3 would no longer match filePutter, and the region and bucket are an implementation detail that writeRandomFile shouldn’t care about. So I need another solution.

There are two basic ways to give a function access to values without passing them as parameters. One way is to use a closure: a constructor function takes the values and returns a function that captures them:

func main() {
	s3Putter := newS3Putter("eu-west-2", "my-bucket")
	if err := writeRandomFile(s3Putter, 10, "filename"); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	fmt.Println("OK")
}

func newS3Putter(region, bucket string) filePutter {
	return func(key string, data io.ReadSeeker) error {
		sess, err := session.NewSession(&aws.Config{
			Region: aws.String(region)},
		)
		if err != nil {
			return err
		}
		svc := s3.New(sess)
		_, err = svc.PutObject(&s3.PutObjectInput{
			Bucket: aws.String(bucket),
			Key:    aws.String(key),
			Body:   data,
		})
		return err
	}
}

The other way is to create a struct with the required fields and use a method:

func main() {
	putter := s3Putter{
		region: "eu-west-2",
		bucket: "my-bucket",
	}
	if err := writeRandomFile(putter.Put, 10, "filename"); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	fmt.Println("OK")
}

type s3Putter struct {
	region string
	bucket string
}

func (s3p s3Putter) Put(key string, data io.ReadSeeker) error {
	sess, err := session.NewSession(&aws.Config{
		Region: aws.String(s3p.region)},
	)
	if err != nil {
		return err
	}
	svc := s3.New(sess)
	_, err = svc.PutObject(&s3.PutObjectInput{
		Bucket: aws.String(s3p.bucket),
		Key:    aws.String(key),
		Body:   data,
	})
	return err
}
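Because s3Putter.Put is passed as a method value (a function with its receiver already bound), the same trick works for test fakes. A minimal sketch, assuming a hypothetical memoryPutter type that keeps every file in a map so that a test can inspect what was written:

```go
package main

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"io"
)

type filePutter func(key string, data io.ReadSeeker) error

func writeRandomFile(putFile filePutter, length int, name string) error {
	data := make([]byte, length)
	if _, err := rand.Read(data); err != nil {
		return fmt.Errorf("writeRandomFile: failed to create random data: %v", err)
	}
	if err := putFile(name, bytes.NewReader(data)); err != nil {
		return fmt.Errorf("writeRandomFile: failed to write file %q: %v", name, err)
	}
	return nil
}

// memoryPutter is a hypothetical test double that stores files in memory.
type memoryPutter struct {
	files map[string][]byte
}

// Put has the same signature as filePutter, so mp.Put can be passed directly.
func (mp *memoryPutter) Put(key string, data io.ReadSeeker) error {
	b, err := io.ReadAll(data)
	if err != nil {
		return err
	}
	mp.files[key] = b
	return nil
}

func main() {
	mp := &memoryPutter{files: map[string][]byte{}}
	// mp.Put is a method value: a func bound to mp.
	if err := writeRandomFile(mp.Put, 10, "filename"); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(len(mp.files["filename"])) // 10
}
```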

Personally, I like closures: they don’t create a new type when we don’t really need one. However, if you use an interface instead of a function signature to allow for alternative implementations, then a struct is a much more natural way to implement that interface. I usually avoid interfaces because they take more effort to implement during unit testing. A function is simple, whereas implementing an interface requires defining a new type, usually a struct (there’s a clever way to make a function implement an interface, but it’s not something I’ve seen people doing outside of http.HandlerFunc, which adapts a function to http.Handler). Here’s the same program rewritten to use an interface:

func main() {
	putter := s3Putter{
		region: "eu-west-2",
		bucket: "my-bucket",
	}
	if err := writeRandomFile(putter, 10, "filename"); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	fmt.Println("OK")
}

type s3Putter struct {
	region string
	bucket string
}

func (s3p s3Putter) Put(key string, data io.ReadSeeker) error {
	sess, err := session.NewSession(&aws.Config{
		Region: aws.String(s3p.region)},
	)
	if err != nil {
		return err
	}
	svc := s3.New(sess)
	_, err = svc.PutObject(&s3.PutObjectInput{
		Bucket: aws.String(s3p.bucket),
		Key:    aws.String(key),
		Body:   data,
	})
	return err
}

type filePutter interface {
	Put(key string, data io.ReadSeeker) error
}

func writeRandomFile(putter filePutter, length int, name string) error {
	data := make([]byte, length)
	_, err := rand.Read(data)
	if err != nil {
		return fmt.Errorf("writeRandomFile: failed to create random data: %v", err)
	}
	err = putter.Put(name, bytes.NewReader(data))
	if err != nil {
		return fmt.Errorf("writeRandomFile: failed to write file %q: %v", name, err)
	}
	return nil
}
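The clever way of making a function implement an interface is the adapter used by http.HandlerFunc: declare a named function type and give it a method that simply calls the function. A sketch for the filePutter interface, where putterFunc and failingPutter are hypothetical names; it lets a test keep using an inline function even though writeRandomFile now takes an interface:

```go
package main

import (
	"bytes"
	"crypto/rand"
	"errors"
	"fmt"
	"io"
)

type filePutter interface {
	Put(key string, data io.ReadSeeker) error
}

// putterFunc adapts an ordinary function to the filePutter interface,
// in the same way http.HandlerFunc adapts a function to http.Handler.
type putterFunc func(key string, data io.ReadSeeker) error

// Put calls f, so any putterFunc satisfies filePutter.
func (f putterFunc) Put(key string, data io.ReadSeeker) error {
	return f(key, data)
}

func writeRandomFile(putter filePutter, length int, name string) error {
	data := make([]byte, length)
	if _, err := rand.Read(data); err != nil {
		return fmt.Errorf("writeRandomFile: failed to create random data: %v", err)
	}
	if err := putter.Put(name, bytes.NewReader(data)); err != nil {
		return fmt.Errorf("writeRandomFile: failed to write file %q: %v", name, err)
	}
	return nil
}

// failingPutter is a fake built from an inline function; it always fails.
var failingPutter = putterFunc(func(key string, data io.ReadSeeker) error {
	return errors.New("failed to read for some reason")
})

func main() {
	err := writeRandomFile(failingPutter, 10, "filename")
	// Prints: writeRandomFile: failed to write file "filename": failed to read for some reason
	fmt.Println(err)
}
```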

This idea of passing alternative implementations into other code is called “Dependency Injection”: the dependencies are injected into the code. However, it’s really annoying for your code’s users if they have to understand and wire up all of your dependencies just because you wanted to do some unit testing, so you still need to consider how end users initialise your code. You might want to put sensible default implementations in place, so that it just works. The functional options pattern is one way to provide defaults while leaving your options open for later: https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
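As a rough illustration of that pattern applied here: a constructor that starts from defaults (the region and bucket used throughout this post) and lets callers override only what they need. Option, withRegion, withBucket, newS3PutterWithOptions and the "prod-bucket" value are all hypothetical, and the S3 Put method is left out so that the sketch runs without AWS credentials.

```go
package main

import "fmt"

type s3Putter struct {
	region string
	bucket string
}

// Option is a hypothetical functional option that mutates an s3Putter during construction.
type Option func(*s3Putter)

// withRegion overrides the default AWS region.
func withRegion(region string) Option {
	return func(p *s3Putter) { p.region = region }
}

// withBucket overrides the default bucket name.
func withBucket(bucket string) Option {
	return func(p *s3Putter) { p.bucket = bucket }
}

// newS3PutterWithOptions applies the defaults first, then any options in order.
func newS3PutterWithOptions(opts ...Option) s3Putter {
	p := s3Putter{region: "eu-west-2", bucket: "my-bucket"}
	for _, opt := range opts {
		opt(&p)
	}
	return p
}

func main() {
	fmt.Println(newS3PutterWithOptions())                          // {eu-west-2 my-bucket}
	fmt.Println(newS3PutterWithOptions(withBucket("prod-bucket"))) // {eu-west-2 prod-bucket}
}
```

Callers who are happy with the defaults pass nothing, so the testing-driven configuration stays out of their way.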

I don’t mind dependencies being wired up in main, because that’s exactly where it should happen. But if writeRandomFile is going to be a dependency of something else (e.g. used within an HTTP handler), the HTTP handler shouldn’t need to know about filePutter, because that’s an internal detail of the writeRandomFile functionality.

The good news is that we already know two ways to take this implementation detail out of the function signature: a closure or a struct. Now we can stack the dependencies up, here using closures throughout:

func main() {
	// Pass the s3Putter into the writer.
	s3Putter := newS3Putter("eu-west-2", "my-bucket")
	writer := newWriter(s3Putter)

	// Pass the writer into the handler.
	handler := handle(writer)

	// Pass the handler into the Web server mux.
	http.Handle("/", handler)

	// Start the Web server. ListenAndServe only returns on failure, so log the error and exit.
	log.Fatal(http.ListenAndServe(":8080", nil))
}

func handle(writeRandom func(length int, name string) error) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		err := writeRandom(10, "filename")
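		if err != nil {
			// Log the details for analysis, but don't expose them to the client.
			log.Printf("handle: %v", err)
			http.Error(w, "internal error", http.StatusInternalServerError)
			return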
		}
	})
}

func newS3Putter(region, bucket string) filePutter {
	return func(key string, data io.ReadSeeker) error {
		sess, err := session.NewSession(&aws.Config{
			Region: aws.String(region)},
		)
		if err != nil {
			return err
		}
		svc := s3.New(sess)
		_, err = svc.PutObject(&s3.PutObjectInput{
			Bucket: aws.String(bucket),
			Key:    aws.String(key),
			Body:   data,
		})
		return err
	}
}

type filePutter func(key string, data io.ReadSeeker) error

func newWriter(put filePutter) func(length int, name string) error {
	return func(length int, name string) error {
		data := make([]byte, length)
		_, err := rand.Read(data)
		if err != nil {
			return fmt.Errorf("writeRandomFile: failed to create random data: %v", err)
		}
		err = put(name, bytes.NewReader(data))
		if err != nil {
			return fmt.Errorf("writeRandomFile: failed to write file %q: %v", name, err)
		}
		return nil
	}
}

With this design, you can swap out implementations for unit testing by passing an inline function, as in the earlier unit test example, instead of needing to create structs and mocks.
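For example, the handler can be exercised with plain functions standing in for the writer, using the standard library's net/http/httptest package. A sketch, where statusFor, failingWriter and succeedingWriter are hypothetical helpers; in a real project the checks would live in a _test.go file:

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// handle is the HTTP handler from the example above.
func handle(writeRandom func(length int, name string) error) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := writeRandom(10, "filename"); err != nil {
			http.Error(w, "internal error", http.StatusInternalServerError)
			return
		}
	})
}

// statusFor serves one request through handle, using the given fake writer,
// and returns the HTTP status code that the recorder captured.
func statusFor(writeRandom func(length int, name string) error) int {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	handle(writeRandom).ServeHTTP(rec, req)
	return rec.Code
}

// failingWriter simulates the S3 upload failing.
func failingWriter(length int, name string) error { return errors.New("S3 unavailable") }

// succeedingWriter simulates a successful upload.
func succeedingWriter(length int, name string) error { return nil }

func main() {
	fmt.Println(statusFor(failingWriter))    // 500
	fmt.Println(statusFor(succeedingWriter)) // 200
}
```

No AWS credentials or network access are needed, because neither fake touches S3.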

Most programming languages work in a similar way, so the process for testing something is much the same: