Creating a unit test
In Go, it’s standard practice to write your tests in *_test.go files which live directly alongside the code that you’re testing. So, in this case, the first thing that we’re going to do is create a new cmd/web/template_test.go file to hold the test.
And then we can create a new unit test for the humanDate function like so:
package main

import (
	"testing"
	"time"
)

func TestHumanDate(t *testing.T) {
	// Initialize a new time.Time object and pass it to the humanDate function.
	tm := time.Date(2024, 3, 17, 10, 15, 0, 0, time.UTC)
	hd := humanDate(tm)

	// Check that the output from the humanDate function is in the format we expect. If it isn't
	// what we expect, use the t.Errorf() function to indicate that the test has failed and log
	// the expected and actual values.
	if hd != "17 Mar 2024 at 10:15" {
		t.Errorf("got %q; want %q", hd, "17 Mar 2024 at 10:15")
	}
}

zzh@ZZHPC:/zdata/Github/snippetbox$ go test -v ./cmd/web
=== RUN TestHumanDate
--- PASS: TestHumanDate (0.00s)
PASS
ok snippetbox/cmd/web 0.002s
Table-driven tests
Let’s now expand our TestHumanDate() function to cover some additional test cases. Specifically, we’re going to update it to also check that:
1. If the input to humanDate() is the zero time, then it returns the empty string "".
2. The output from the humanDate() function always uses the UTC time zone.
In Go, an idiomatic way to run multiple test cases is to use table-driven tests.
Essentially, the idea behind table-driven tests is to create a ‘table’ of test cases containing the inputs and expected outputs, and to then loop over these, running each test case in a sub-test. There are a few ways you could set this up, but a common approach is to define your test cases in a slice of anonymous structs.
package main

import (
	"testing"
	"time"
)

func TestHumanDate(t *testing.T) {
	// Create a slice of anonymous structs containing the test case name, input to our humanDate()
	// function, and expected output.
	tests := []struct {
		name     string
		input    time.Time
		expected string
	}{
		{
			name:     "UTC",
			input:    time.Date(2024, 3, 17, 10, 15, 0, 0, time.UTC),
			expected: "17 Mar 2024 at 10:15",
		},
		{
			name:     "Empty",
			input:    time.Time{},
			expected: "",
		},
		{
			name:     "CET",
			input:    time.Date(2024, 3, 17, 10, 15, 0, 0, time.FixedZone("CET", 1*60*60)),
			expected: "17 Mar 2024 at 09:15",
		},
	}

	// Loop over the test cases.
	for _, tc := range tests {
		// Use the t.Run() function to run a sub-test for each test case. The first parameter to
		// this is the name of the test (which is used to identify the sub-test in any log output)
		// and the second parameter is an anonymous function containing the actual test for each
		// case.
		t.Run(tc.name, func(t *testing.T) {
			hd := humanDate(tc.input)

			if hd != tc.expected {
				t.Errorf("got %q; expect %q", hd, tc.expected)
			}
		})
	}
}

zzh@ZZHPC:/zdata/Github/snippetbox$ go test -v ./cmd/web
=== RUN TestHumanDate
=== RUN TestHumanDate/UTC
=== RUN TestHumanDate/Empty
    template_test.go:43: got "01 Jan 0001 at 00:00"; expect ""
=== RUN TestHumanDate/CET
    template_test.go:43: got "17 Mar 2024 at 10:15"; expect "17 Mar 2024 at 09:15"
--- FAIL: TestHumanDate (0.00s)
    --- PASS: TestHumanDate/UTC (0.00s)
    --- FAIL: TestHumanDate/Empty (0.00s)
    --- FAIL: TestHumanDate/CET (0.00s)
FAIL
FAIL snippetbox/cmd/web 0.002s
FAIL
So here we can see the individual output for each of our sub-tests. As you might have guessed, our first test case passed but the Empty and CET tests both failed. Notice how, for the failed test cases, we get the relevant failure message along with the filename and line number in the output?
Let’s head back to our humanDate() function and update it to fix these two problems:
func humanDate(t time.Time) string {
	if t.IsZero() {
		return ""
	}

	// Convert the time to UTC before formatting it.
	return t.UTC().Format("02 Jan 2006 at 15:04")
}

zzh@ZZHPC:/zdata/Github/snippetbox$ go test -v ./cmd/web
=== RUN TestHumanDate
=== RUN TestHumanDate/UTC
=== RUN TestHumanDate/Empty
=== RUN TestHumanDate/CET
--- PASS: TestHumanDate (0.00s)
    --- PASS: TestHumanDate/UTC (0.00s)
    --- PASS: TestHumanDate/Empty (0.00s)
    --- PASS: TestHumanDate/CET (0.00s)
PASS
ok snippetbox/cmd/web 0.002s
Helpers for test assertions
As I mentioned briefly earlier, over the next few chapters we’ll be writing a lot of test assertions that are a variation of this pattern:
if actualValue != expectedValue {
	t.Errorf("got %v; want %v", actualValue, expectedValue)
}
Let’s quickly abstract this code into a helper function.
If you’re following along, go ahead and create a new internal/assert package containing an assert.go file, and then add the following code:
package assert

import "testing"

func Equal[T comparable](t *testing.T, actual, expected T) {
	t.Helper()

	if actual != expected {
		t.Errorf("got: %v; expect: %v", actual, expected)
	}
}
Note: The t.Helper() function that we’re using in the code above indicates to the Go test runner that our Equal() function is a test helper. This means that when t.Errorf() is called from our Equal() function, the Go test runner will report the filename and line number of the code which called our Equal() function in the output.
With that in place, we can simplify our TestHumanDate() test like so:
package main

import (
	"snippetbox/internal/assert"
	"testing"
	"time"
)

func TestHumanDate(t *testing.T) {
	tests := []struct {
		name     string
		input    time.Time
		expected string
	}{
		{
			name:     "UTC",
			input:    time.Date(2024, 3, 17, 10, 15, 0, 0, time.UTC),
			expected: "17 Mar 2024 at 10:15",
		},
		{
			name:     "Empty",
			input:    time.Time{},
			expected: "",
		},
		{
			name:     "CET",
			input:    time.Date(2024, 3, 17, 10, 15, 0, 0, time.FixedZone("CET", 1*60*60)),
			expected: "17 Mar 2024 at 09:15",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			hd := humanDate(tc.input)

			// Use the new assert.Equal() helper to compare the actual and expected values.
			assert.Equal(t, hd, tc.expected)
		})
	}
}
Sub-tests without a table of test cases
It’s important to point out that you don’t need to use sub-tests in conjunction with table-driven tests (like we have done so far). It’s perfectly valid to execute sub-tests by calling t.Run() consecutively in your test functions, similar to this:
func TestExample(t *testing.T) {
	t.Run("Example sub-test 1", func(t *testing.T) {
		// Do a test.
	})

	t.Run("Example sub-test 2", func(t *testing.T) {
		// Do another test.
	})

	t.Run("Example sub-test 3", func(t *testing.T) {
		// And another...
	})
}
Testing HTTP handlers and middleware
Let’s move on and discuss some specific techniques for unit testing your HTTP handlers.
All the handlers that we’ve written for this project so far are a bit complex to test, so to introduce things I’d prefer to start off with something simpler.
So, if you’re following along, head over to your handlers.go file and create a new ping handler function which returns a 200 OK status code and an "OK" response body. It’s the type of handler that you might want to implement for status-checking or uptime monitoring of your server.
func ping(w http.ResponseWriter, r *http.Request) {
	w.Write([]byte("OK"))
}
Recording responses
Go provides a bunch of useful tools in the net/http/httptest package for helping to test your HTTP handlers.
One of these tools is the httptest.ResponseRecorder type. This is essentially an implementation of http.ResponseWriter which records the response status code, headers and body instead of actually writing them to an HTTP connection.
So an easy way to unit test your handlers is to create a new httptest.ResponseRecorder, pass it to the handler function, and then examine it again after the handler returns.
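As a quick illustration of that idea outside the project, the following standalone sketch records a response from a copy of the ping handler. It uses httptest.NewRequest(), a convenience constructor from the same package that panics on bad input instead of returning an error, as an alternative to the http.NewRequest() call used in the test below. The recordPing() helper is a hypothetical name for this illustration.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

// ping is a copy of the handler above: it writes a body but never calls
// WriteHeader(), so an implicit 200 OK status is used.
func ping(w http.ResponseWriter, r *http.Request) {
	w.Write([]byte("OK"))
}

// recordPing runs ping against a ResponseRecorder and returns the recorded
// status code and body.
func recordPing() (int, string, error) {
	rr := httptest.NewRecorder()
	r := httptest.NewRequest(http.MethodGet, "/ping", nil)

	ping(rr, r)

	res := rr.Result()
	defer res.Body.Close()

	body, err := io.ReadAll(res.Body)
	if err != nil {
		return 0, "", err
	}
	return res.StatusCode, string(body), nil
}

func main() {
	code, body, err := recordPing()
	if err != nil {
		panic(err)
	}
	fmt.Println(code, body) // 200 OK
}
```

Because ping never calls WriteHeader(), the recorder reports the same implicit 200 OK status that a real server would send.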
First, follow the Go conventions and create a new handlers_test.go file to hold the test.
package main

import (
	"bytes"
	"io"
	"net/http"
	"net/http/httptest"
	"snippetbox/internal/assert"
	"testing"
)

func TestPing(t *testing.T) {
	// Initialize a new httptest.ResponseRecorder.
	rr := httptest.NewRecorder()

	// Initialize a new dummy http.Request.
	r, err := http.NewRequest(http.MethodGet, "/", nil)
	if err != nil {
		t.Fatal(err)
	}

	// Call the ping handler function, passing in the httptest.ResponseRecorder and http.Request.
	ping(rr, r)

	// Call the Result() method on the httptest.ResponseRecorder to get the http.Response
	// generated by the ping handler.
	res := rr.Result()

	// Check that the status code written by the ping handler was 200.
	assert.Equal(t, res.StatusCode, http.StatusOK)

	// And we can check that the response body written by the ping handler equals "OK".
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	body = bytes.TrimSpace(body)

	assert.Equal(t, string(body), "OK")
}
Note: In the code above we use the t.Fatal() function in a couple of places to handle situations where there is an unexpected error in our test code. When called, t.Fatal() will mark the test as failed, log the error, and then completely stop execution of the current test (or sub-test). Typically you should call t.Fatal() in situations where it doesn’t make sense to continue the current test — such as an error during a setup step, or where an unexpected error from a Go standard library function means you can’t proceed with the test.
zzh@ZZHPC:/zdata/Github/snippetbox$ go test -v ./cmd/web
=== RUN TestPing
--- PASS: TestPing (0.00s)
=== RUN TestHumanDate
=== RUN TestHumanDate/UTC
=== RUN TestHumanDate/Empty
=== RUN TestHumanDate/CET
--- PASS: TestHumanDate (0.00s)
    --- PASS: TestHumanDate/UTC (0.00s)
    --- PASS: TestHumanDate/Empty (0.00s)
    --- PASS: TestHumanDate/CET (0.00s)
PASS
ok snippetbox/cmd/web 0.002s
Testing middleware
First you’ll need to create a cmd/web/middleware_test.go file to hold the test.
package main

import (
	"bytes"
	"io"
	"net/http"
	"net/http/httptest"
	"snippetbox/internal/assert"
	"testing"
)

func TestCommonHeaders(t *testing.T) {
	rr := httptest.NewRecorder()

	r, err := http.NewRequest(http.MethodGet, "/", nil)
	if err != nil {
		t.Fatal(err)
	}

	// Create a mock HTTP handler that we can pass to our commonHeaders middleware, which writes
	// a 200 status code and "OK" response body.
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("OK"))
	})

	// Pass the mock HTTP handler to our commonHeaders middleware. Because commonHeaders *returns*
	// an http.Handler we can call its ServeHTTP() method, passing in the httptest.ResponseRecorder
	// and dummy http.Request to execute it.
	commonHeaders(next).ServeHTTP(rr, r)

	// Call the Result() method on the httptest.ResponseRecorder to get the results of the test.
	res := rr.Result()

	// Check that the middleware has correctly set the Content-Security-Policy header on the
	// response.
	expectedValue := "default-src 'self'; style-src 'self' fonts.googleapis.com; font-src fonts.gstatic.com"
	assert.Equal(t, res.Header.Get("Content-Security-Policy"), expectedValue)

	// Check that the middleware has correctly set the Referrer-Policy header on the response.
	expectedValue = "origin-when-cross-origin"
	assert.Equal(t, res.Header.Get("Referrer-Policy"), expectedValue)

	// Check that the middleware has correctly set the X-Content-Type-Options header on the
	// response.
	expectedValue = "nosniff"
	assert.Equal(t, res.Header.Get("X-Content-Type-Options"), expectedValue)

	// Check that the middleware has correctly set the X-Frame-Options header on the response.
	expectedValue = "deny"
	assert.Equal(t, res.Header.Get("X-Frame-Options"), expectedValue)

	// Check that the middleware has correctly set the X-XSS-Protection header on the response.
	expectedValue = "0"
	assert.Equal(t, res.Header.Get("X-XSS-Protection"), expectedValue)

	// Check that the middleware has correctly set the Server header on the response.
	expectedValue = "Go"
	assert.Equal(t, res.Header.Get("Server"), expectedValue)

	// Check that the middleware has correctly called the next handler in line and the response
	// status code and body are as expected.
	assert.Equal(t, res.StatusCode, http.StatusOK)

	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	body = bytes.TrimSpace(body)

	assert.Equal(t, string(body), "OK")
}

zzh@ZZHPC:/zdata/Github/snippetbox$ go test -v ./cmd/web
=== RUN TestPing
--- PASS: TestPing (0.00s)
=== RUN TestCommonHeaders
--- PASS: TestCommonHeaders (0.00s)
=== RUN TestHumanDate
=== RUN TestHumanDate/UTC
=== RUN TestHumanDate/Empty
=== RUN TestHumanDate/CET
--- PASS: TestHumanDate (0.00s)
    --- PASS: TestHumanDate/UTC (0.00s)
    --- PASS: TestHumanDate/Empty (0.00s)
    --- PASS: TestHumanDate/CET (0.00s)
PASS
ok snippetbox/cmd/web 0.002s
So, in summary, a quick and easy way to unit test your HTTP handlers and middleware is to simply call them using the httptest.ResponseRecorder type. You can then examine the status code, headers and response body of the recorded response to make sure that they are working as expected.
End-to-end testing
We’re going to explain how to run end-to-end tests on your web application that encompass your routing, middleware and handlers. In most cases, end-to-end testing should give you more confidence that your application is working correctly than unit testing in isolation.
To illustrate this, we’ll adapt our TestPing function so that it runs an end-to-end test on our code. Specifically, we want the test to ensure that a GET /ping request to our application calls the ping handler function and results in a 200 OK status code and "OK" response body.
Essentially, we want to test that our application has a GET /ping route which maps to the ping handler.
Using httptest.Server
The key to end-to-end testing our application is the httptest.NewTLSServer() function, which spins up an httptest.Server instance that we can make HTTPS requests to.
The whole pattern is a bit too complicated to explain upfront, so it’s probably best to demonstrate first by writing the code and then we’ll talk through the details afterwards.
With that in mind, head back to your handlers_test.go file and update the TestPing test so that it looks like this:
package main

import (
	"bytes"
	"io"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"snippetbox/internal/assert"
	"testing"
)

func TestPing(t *testing.T) {
	// Create a new instance of our application struct. For now, this just contains a structured
	// logger (which discards anything written to it).
	app := &application{
		logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
	}

	// We then use the httptest.NewTLSServer() function to create a new test server, passing in
	// the value returned by our app.routes() method as the handler for the server. This starts up
	// an HTTPS server which listens on a randomly-chosen port on your local machine for the
	// duration of the test. Notice that we defer a call to ts.Close() so that the server is
	// shut down when the test finishes.
	ts := httptest.NewTLSServer(app.routes())
	defer ts.Close()

	// The network address that the test server is listening on is contained in the ts.URL field.
	// We can use this along with the ts.Client().Get() method to make a GET /ping request against
	// the test server. This returns an http.Response struct containing the response.
	res, err := ts.Client().Get(ts.URL + "/ping")
	if err != nil {
		t.Fatal(err)
	}

	// We can then check the value of the response status code and body.
	assert.Equal(t, res.StatusCode, http.StatusOK)

	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	body = bytes.TrimSpace(body)

	assert.Equal(t, string(body), "OK")
}
There are a few things about this code to point out and discuss.
- When we call httptest.NewTLSServer() to initialize the test server we need to pass in an http.Handler as the parameter, and this handler is called each time the test server receives an HTTPS request. In our case, we’ve passed in the return value from our app.routes() method, meaning that a request to the test server will use all our real application routes, middleware and handlers. This is a big upside of the work that we did earlier in the book to isolate all our application routing in the app.routes() method.
- If you’re testing an HTTP (not HTTPS) server you should use the httptest.NewServer() function to create the test server instead.
- The ts.Client() method returns the test server client, which has the type *http.Client and is preconfigured to trust the test server’s TLS certificate. We should always use this client to send requests to the test server. It’s possible to configure the client to tweak its behavior, and we’ll explain how to do that at the end of this chapter.
- You might be wondering why we have set the logger field of our application struct, but none of the other fields. The reason for this is that the logger is needed by the logRequest and recoverPanic middlewares, which are used by our application on every route. Trying to run this test without setting this dependency will result in a panic.
Anyway, let’s try out the new test:
zzh@ZZHPC:/zdata/Github/snippetbox$ go test ./cmd/web
--- FAIL: TestPing (0.00s)
    handlers_test.go:37: got: 404; expect: 200
    handlers_test.go:46: got: 404 page not found; expect: OK
FAIL
FAIL snippetbox/cmd/web 0.006s
FAIL
We can see from the test output that the response from our GET /ping request has a 404 status code, rather than the 200 we expected. And that’s because we haven’t actually registered a GET /ping route with our router yet.
Let’s fix that now:
package main

import (
	"net/http"
	"snippetbox/ui"

	"github.com/justinas/alice"
)

func (app *application) routes() http.Handler {
	mux := http.NewServeMux()

	mux.Handle("GET /static/", http.FileServerFS(ui.Files))

	mux.HandleFunc("GET /ping", ping)

	// Unprotected routes using the "dynamic" middleware chain.
	dynamic := alice.New(app.sessionManager.LoadAndSave, noSurf, app.authenticate)

	mux.Handle("GET /{$}", dynamic.ThenFunc(app.home))
	mux.Handle("GET /snippet/view/{id}", dynamic.ThenFunc(app.snippetView))
	mux.Handle("GET /user/signup", dynamic.ThenFunc(app.userSignup))
	mux.Handle("POST /user/signup", dynamic.ThenFunc(app.userSignupPost))
	mux.Handle("GET /user/login", dynamic.ThenFunc(app.userLogin))
	mux.Handle("POST /user/login", dynamic.ThenFunc(app.userLoginPost))

	// Protected (authenticated-only) routes using the "protected" middleware chain which includes
	// the requireAuthentication middleware.
	protected := dynamic.Append(app.requireAuthentication)

	mux.Handle("GET /snippet/create", protected.ThenFunc(app.snippetCreate))
	mux.Handle("POST /snippet/create", protected.ThenFunc(app.snippetCreatePost))
	mux.Handle("POST /user/logout", protected.ThenFunc(app.userLogoutPost))

	standard := alice.New(app.recoverPanic, app.logRequest, commonHeaders)
	return standard.Then(mux)
}

zzh@ZZHPC:/zdata/Github/snippetbox$ go test ./cmd/web
ok snippetbox/cmd/web 0.006s
Using test helpers
Our TestPing test is now working nicely. But there’s a good opportunity to break out some of this code into helper functions, which we can reuse as we add more end-to-end tests to our project.
There are no hard-and-fast rules about where to put helper methods for tests. If a helper is only used in a specific *_test.go file, then it probably makes sense to include it inline in that file alongside your tests. At the other end of the spectrum, if you are going to use a helper in tests across multiple packages, then you might want to put it in a reusable package called internal/testutils (or similar) which can be imported by your test files.
In our case, the helpers will be used for testing code throughout our cmd/web package but nowhere else, so it seems reasonable to put them in a new cmd/web/testhelpers_test.go file.
package main

import (
	"bytes"
	"io"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"testing"
)

// This helper returns an instance of our application struct containing mocked dependencies.
func newTestApplication(t *testing.T) *application {
	return &application{
		logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
	}
}

// Define a custom testServer type which embeds an httptest.Server instance.
type testServer struct {
	*httptest.Server
}

// This helper initializes and returns a new instance of our custom testServer type.
func newTestServer(t *testing.T, h http.Handler) *testServer {
	ts := httptest.NewTLSServer(h)
	return &testServer{ts}
}

// Implement a get() method on our custom testServer type. This makes a GET request to a given url
// path using the test server client, and returns the response status code, headers and body.
func (ts *testServer) get(t *testing.T, urlPath string) (int, http.Header, string) {
	res, err := ts.Client().Get(ts.URL + urlPath)
	if err != nil {
		t.Fatal(err)
	}

	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	body = bytes.TrimSpace(body)

	return res.StatusCode, res.Header, string(body)
}
Let’s head back to our TestPing test and put these new helpers to work:

package main

import (
	"net/http"
	"snippetbox/internal/assert"
	"testing"
)

func TestPing(t *testing.T) {
	app := newTestApplication(t)

	ts := newTestServer(t, app.routes())
	defer ts.Close()

	code, _, body := ts.get(t, "/ping")

	assert.Equal(t, code, http.StatusOK)
	assert.Equal(t, body, "OK")
}

zzh@ZZHPC:/zdata/Github/snippetbox$ go test ./cmd/web
ok snippetbox/cmd/web 0.006s
This is shaping up nicely now. We have a neat pattern in place for spinning up a test server and making requests to it, encompassing our routing, middleware and handlers in an end-to-end test. We’ve also broken apart some of the code into helpers, which will make writing future tests quicker and easier.
Cookies and redirections
So far in this chapter we’ve been using the default test server client settings. But there are a couple of changes I’d like to make so that it’s better suited to testing our web application. Specifically:
- We want the client to automatically store any cookies sent in an HTTPS response, so that we can include them (if appropriate) in any subsequent requests back to the test server. This will come in handy later in the book when we need cookies to be supported across multiple requests in order to test our anti-CSRF measures.
- We don’t want the client to automatically follow redirects. Instead we want it to return the first HTTPS response sent by our server so that we can test the response for that specific request.
To make these changes, let’s go back to the testhelpers_test.go file we just created and update the newTestServer() function like so:
package main

import (
	"bytes"
	"io"
	"log/slog"
	"net/http"
	"net/http/cookiejar"
	"net/http/httptest"
	"testing"
)

...

func newTestServer(t *testing.T, h http.Handler) *testServer {
	ts := httptest.NewTLSServer(h)

	// Initialize a new cookie jar.
	jar, err := cookiejar.New(nil)
	if err != nil {
		t.Fatal(err)
	}

	// Add the cookie jar to the test server client. Any response cookies will now be stored and
	// sent with subsequent requests when using this client.
	ts.Client().Jar = jar

	// Disable redirect-following for the test server client by setting a custom CheckRedirect
	// function. This function will be called whenever a 3xx response is received by the client,
	// and by always returning an http.ErrUseLastResponse error it forces the client to
	// immediately return the received response.
	ts.Client().CheckRedirect = func(req *http.Request, via []*http.Request) error {
		return http.ErrUseLastResponse
	}

	return &testServer{ts}
}

...
Controlling which tests are run
Use the -run flag. This allows you to specify a regular expression, and only tests with a name that matches the regular expression will be run. For example, to run only the TestPing test:
$ go test -v -run="^TestPing$" ./cmd/web/
Test caching
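When you run go test against specific packages (like ./cmd/web), the results of passing test runs are cached and reused, shown as (cached) in the output, if nothing relevant has changed. If you want to force your tests to run in full (and avoid the cache) you can use the -count=1 flag:
$ go test -count=1 ./cmd/web/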
Fast failure
If you would prefer to terminate the tests immediately after the first failure you can use the -failfast flag.
It’s important to note that the -failfast flag only stops tests in the package that had the failure. If you are running tests in multiple packages (for example, by using go test ./...), then the tests in the other packages will continue to run.
Parallel testing
By default, the go test command executes the tests within each package serially, one after another. When you have a small number of tests (like we do) and the runtime is very fast, this is absolutely fine.
But if you have hundreds or thousands of tests the total run time can start adding up to something more meaningful. And in that scenario, you may save yourself some time by running the tests in parallel.
You can indicate that it’s OK for a test to be run concurrently alongside other tests by calling the t.Parallel() function at the start of the test. For example:
func TestPing(t *testing.T) {
	t.Parallel()

	...
}
It’s important to note here that:
- Tests marked using t.Parallel() will be run in parallel with — and only with — other parallel tests.
- By default, the maximum number of tests that will be run simultaneously is the current value of GOMAXPROCS. You can override this by setting a specific value via the -parallel flag. For example:
$ go test -parallel=4 ./...
- Not all tests are suitable to be run in parallel. For example, if you have an integration test which requires a database table to be in a specific known state, then you wouldn’t want to run it in parallel with other tests that manipulate the same database table.
Enabling the race detector
The go test command includes a -race flag which enables Go’s race detector when running tests.
If the code you’re testing leverages concurrency, or you’re running tests in parallel, enabling this can be a good idea to help flag up race conditions that exist in your application. You can use it like so:
$ go test -race ./cmd/web/
You should be aware that the race detector is limited in its usefulness: it’s just a tool that flags data races if and when they occur at runtime during testing. It doesn’t carry out static analysis of your codebase, and a clean run doesn’t guarantee that your code is free of race conditions.
Enabling the race detector will also increase the overall running time of your tests. So if you’re running tests very frequently as part of a TDD workflow, you may prefer to use the -race flag during pre-commit test runs only.
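For a feel of what the race detector looks for, here’s a small sketch of a counter that is incremented from many goroutines. With the mutex in place it’s safe. If you delete the Lock()/Unlock() calls and run it with the -race flag (go run -race works too), the detector would typically report a data race, but only because the racing accesses actually happen during that run. The counter type and hammer() helper are illustrative and not part of the project.

```go
package main

import (
	"fmt"
	"sync"
)

// counter is safe for concurrent use because every access to n holds mu.
// Removing the locking would create a data race on n.
type counter struct {
	mu sync.Mutex
	n  int
}

func (c *counter) inc() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.n++
}

func (c *counter) value() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.n
}

// hammer increments c once from each of the given number of goroutines and
// waits for all of them to finish.
func hammer(c *counter, goroutines int) {
	var wg sync.WaitGroup
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.inc()
		}()
	}
	wg.Wait()
}

func main() {
	c := &counter{}
	hammer(c, 100)
	fmt.Println(c.value()) // 100
}
```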
Mocking dependencies
Throughout this project we’ve injected dependencies into our handlers via the application struct.
When testing, it sometimes makes sense to mock these dependencies instead of using exactly the same ones that you do in your production application.
For example, in the previous chapter we mocked the logger dependency with a logger that writes messages to io.Discard, instead of the os.Stdout stream that we use in our production application.
The reason for mocking this and writing to io.Discard is to avoid clogging up our test output with unnecessary log messages when we run go test -v (with verbose mode enabled).
Note: Depending on your background and programming experience, you might not consider this logger to be a mock. You might call it a fake, stub or something else entirely. But the name doesn’t really matter, and different people call them different things. What’s important is that we’re using something which exposes the same interface as a production object for the purpose of testing.
The other two dependencies that it makes sense for us to mock are the models.SnippetModel and models.UserModel database models. By creating mocks of these it’s possible for us to test the behavior of our handlers without needing to set up an entire test instance of the MySQL database.
Mocking the database models
If you’re following along, create a new internal/models/mocks package containing snippet.go and user.go files to hold the database model mocks.
package mocks

import (
	"snippetbox/internal/models"
	"time"
)

var mockSnippet = models.Snippet{
	ID:      1,
	Title:   "An old silent pond",
	Content: "An old silent pond...",
	Created: time.Now(),
	Expires: time.Now(),
}

type SnippetModel struct{}

func (m *SnippetModel) Insert(title string, content string, expires int) (int, error) {
	return 2, nil
}

func (m *SnippetModel) Get(id int) (models.Snippet, error) {
	switch id {
	case 1:
		return mockSnippet, nil
	default:
		return models.Snippet{}, models.ErrNoRecord
	}
}

func (m *SnippetModel) Latest() ([]models.Snippet, error) {
	return []models.Snippet{mockSnippet}, nil
}

package mocks

import "snippetbox/internal/models"

type UserModel struct{}

func (m *UserModel) Insert(name, email, password string) error {
	switch email {
	case "dupe@example.com":
		return models.ErrDuplicateEmail
	default:
		return nil
	}
}

func (m *UserModel) Authenticate(email, password string) (int, error) {
	if email == "alice@example.com" && password == "pa$$word" {
		return 1, nil
	}

	return 0, models.ErrInvalidCredentials
}

func (m *UserModel) Exists(id int) (bool, error) {
	switch id {
	case 1:
		return true, nil
	default:
		return false, nil
	}
}
Initializing the mocks
For the next step in our build, let’s head back to the testhelpers_test.go file and update the newTestApplication() function so that it creates an application struct with all the necessary dependencies for testing.
If you make that change and then try to run the tests, they will fail to compile. This is happening because our application struct is expecting pointers to models.SnippetModel and models.UserModel instances, but we are trying to use pointers to mocks.SnippetModel and mocks.UserModel instances instead.
The idiomatic fix for this is to change our application struct so that it uses interfaces which are satisfied by both our mock and production database models.
To do this, let’s head back to our internal/models/snippet.go file and create a new SnippetModelInterface interface type that describes the methods that our actual SnippetModel struct has.
type SnippetModelInterface interface {
	Insert(title string, content string, expires int) (int, error)
	Get(id int) (Snippet, error)
	Latest() ([]Snippet, error)
}

And let’s do the same thing for our UserModel struct:

type UserModelInterface interface {
	Insert(name, email, password string) error
	Authenticate(email, password string) (int, error)
	Exists(id int) (bool, error)
}
Now that we’ve defined those interface types, let’s update our application struct to use them instead of the concrete SnippetModel and UserModel types. Like so:
type application struct {
	logger         *slog.Logger
	snippet        models.SnippetModelInterface // Use our new interface type.
	user           models.UserModelInterface    // Use our new interface type.
	templateCache  map[string]*template.Template
	formDecoder    *form.Decoder
	sessionManager *scs.SessionManager
}
And if you try running the tests again now, everything should work correctly.
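To make the pattern concrete without the rest of the project, here’s a trimmed-down sketch. Snippet, ErrNoRecord, SnippetGetter, mockSnippetModel and app are hypothetical stand-ins for models.Snippet, models.ErrNoRecord, models.SnippetModelInterface, mocks.SnippetModel and the application struct. Because app depends on an interface rather than a concrete type, any type with a matching Get() method can be plugged in.

```go
package main

import (
	"errors"
	"fmt"
)

// Snippet and ErrNoRecord are simplified stand-ins for the models package.
type Snippet struct {
	ID    int
	Title string
}

var ErrNoRecord = errors.New("models: no matching record found")

// SnippetGetter is a one-method slice of SnippetModelInterface.
type SnippetGetter interface {
	Get(id int) (Snippet, error)
}

// mockSnippetModel behaves like mocks.SnippetModel.Get: only ID 1 exists.
type mockSnippetModel struct{}

func (m *mockSnippetModel) Get(id int) (Snippet, error) {
	if id == 1 {
		return Snippet{ID: 1, Title: "An old silent pond"}, nil
	}
	return Snippet{}, ErrNoRecord
}

// Compile-time check that the mock satisfies the interface.
var _ SnippetGetter = (*mockSnippetModel)(nil)

// app depends on the interface, not on a concrete model type.
type app struct {
	snippet SnippetGetter
}

// title returns the snippet title, or "not found" when the model reports
// ErrNoRecord.
func (a *app) title(id int) (string, error) {
	s, err := a.snippet.Get(id)
	if errors.Is(err, ErrNoRecord) {
		return "not found", nil
	} else if err != nil {
		return "", err
	}
	return s.Title, nil
}

func main() {
	a := &app{snippet: &mockSnippetModel{}}
	for _, id := range []int{1, 2} {
		t, _ := a.title(id)
		fmt.Println(id, t)
	}
}
```

The var _ SnippetGetter = (*mockSnippetModel)(nil) line is a common Go idiom: it costs nothing at runtime but makes the compiler check that the mock satisfies the interface. You could add similar lines, such as var _ models.SnippetModelInterface = (*SnippetModel)(nil), to the mocks package so that a mismatched method signature shows up as an error there rather than in the test helpers.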
Testing the snippetView handler
With that all now set up, let’s get stuck into writing an end-to-end test for our snippetView handler which uses these mocked dependencies.
As part of this test, the code in our snippetView handler will call the mocks.SnippetModel.Get() method. Just to remind you, this mocked model method returns a models.ErrNoRecord error unless the snippet ID is 1, in which case it returns the following mock snippet:

var mockSnippet = models.Snippet{
	ID:      1,
	Title:   "An old silent pond",
	Content: "An old silent pond...",
	Created: time.Now(),
	Expires: time.Now(),
}
So specifically, we want to test that:
1. For the request GET /snippet/view/1 we receive a 200 OK response with the relevant mocked snippet contained in the HTML response body.
2. For all other requests to GET /snippet/view/* we should receive a 404 Not Found response.
For the first part here, we want to check that the response body contains some specific content, rather than being exactly equal to it. Let’s quickly add a new StringContains() function to our assert package to help with that (assert.go will now also need to import the strings package):

func StringContains(t *testing.T, actual, expectedSubstring string) {
	t.Helper()

	if !strings.Contains(actual, expectedSubstring) {
		t.Errorf("got: %q; expected to contain: %q", actual, expectedSubstring)
	}
}
And then open up the cmd/web/handlers_test.go file and create a new TestSnippetView test like so:
func TestSnippetView(t *testing.T) { // Create a new instance of our application struct which uses the mocked dependencies. app := newTestApplication(t) // Establish a new test server for running end-to-end tests. ts := newTestServer(t, app.routes()) defer ts.Close() tests := []struct { name string urlPath string expectCode int expectBody string }{ { name: "Valid ID", urlPath: "/snippet/view/1", expectCode: http.StatusOK, expectBody: "An old silent pond...", }, { name: "Non-existent ID", urlPath: "/snippet/view/2", expectCode: http.StatusNotFound, }, { name: "Negative ID", urlPath: "/snippet/view/-1", expectCode: http.StatusNotFound, }, { name: "Decimal ID", urlPath: "/snippet/view/1.23", expectCode: http.StatusNotFound, }, { name: "String ID", urlPath: "/snippet/view/foo", expectCode: http.StatusNotFound, }, { name: "Empty ID", urlPath: "/snippet/view/", expectCode: http.StatusNotFound, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { code, _, body := ts.get(t, tc.urlPath) assert.Equal(t, code, tc.expectCode) if tc.expectBody != "" { assert.StringContains(t, body, tc.expectBody) } }) } }
zzh@ZZHPC:/zdata/Github/snippetbox$ go test -v ./cmd/web
=== RUN   TestPing
--- PASS: TestPing (0.00s)
=== RUN   TestSnippetView
=== RUN   TestSnippetView/Valid_ID
=== RUN   TestSnippetView/Non-existent_ID
=== RUN   TestSnippetView/Negative_ID
=== RUN   TestSnippetView/Decimal_ID
=== RUN   TestSnippetView/String_ID
=== RUN   TestSnippetView/Empty_ID
--- PASS: TestSnippetView (0.00s)
    --- PASS: TestSnippetView/Valid_ID (0.00s)
    --- PASS: TestSnippetView/Non-existent_ID (0.00s)
    --- PASS: TestSnippetView/Negative_ID (0.00s)
    --- PASS: TestSnippetView/Decimal_ID (0.00s)
    --- PASS: TestSnippetView/String_ID (0.00s)
    --- PASS: TestSnippetView/Empty_ID (0.00s)
=== RUN   TestCommonHeaders
--- PASS: TestCommonHeaders (0.00s)
=== RUN   TestHumanDate
=== RUN   TestHumanDate/UTC
=== RUN   TestHumanDate/Empty
=== RUN   TestHumanDate/CET
--- PASS: TestHumanDate (0.00s)
    --- PASS: TestHumanDate/UTC (0.00s)
    --- PASS: TestHumanDate/Empty (0.00s)
    --- PASS: TestHumanDate/CET (0.00s)
PASS
ok  	snippetbox/cmd/web	0.010s
As an aside, notice how the names of the sub-tests have been canonicalized? Go automatically replaces any spaces in the sub-test name with an underscore (and any non-printable characters will also be escaped) in the test output.
Testing HTML forms
We’re going to add an end-to-end test for the POST /user/signup route, which is handled by our userSignupPost handler.
Testing this route is made a bit more complicated by the anti-CSRF check that our application does. Any request that we make to POST /user/signup will always receive a 400 Bad Request response unless the request contains a valid CSRF token and cookie. To get around this we need to simulate the workflow of a real-life user as part of our test, like so:
1. Make a GET /user/signup request. This will return a response which contains a CSRF cookie in the response headers and the CSRF token for the signup page in the response body.
2. Extract the CSRF token from the HTML response body.
3. Make a POST /user/signup request, using the same http.Client that we used in step 1 (so it automatically passes the CSRF cookie with the POST request) and including the CSRF token alongside the other POST data that we want to test.
Let’s begin by adding a new helper function to our cmd/web/testhelpers_test.go file for extracting the CSRF token (if one exists) from an HTML response body:
// Define a regular expression which captures the CSRF token value from the HTML for our user
// signup page.
var csrfTokenRX = regexp.MustCompile(`<input type="hidden" name="csrf_token" value="(.+)">`)

func extractCSRFToken(t *testing.T, body string) string {
	// Use the FindStringSubmatch method to extract the token from the HTML body. Note that this
	// returns a slice with the entire matched pattern in the first position, and the values of
	// any captured data in the subsequent positions.
	matches := csrfTokenRX.FindStringSubmatch(body)
	if len(matches) < 2 {
		t.Fatal("no csrf token found in body")
	}

	return html.UnescapeString(matches[1])
}
Note: You might be wondering why we use the html.UnescapeString() function before returning the CSRF token. The reason is that Go’s html/template package automatically escapes all dynamically rendered data, including our CSRF token. Because the CSRF token is a base64-encoded string, it may include the + character, which will be escaped to &#43;. So after extracting the token from the HTML, we need to run it through html.UnescapeString() to get the original token value.
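A quick way to confirm this behavior is to render a token through html/template and extract it again. The one-line template here is a hypothetical stand-in for the signup page, and the token value is made up; the regular expression is the same pattern as csrfTokenRX.

```go
package main

import (
	"bytes"
	"fmt"
	"html"
	"html/template"
	"regexp"
)

// Same pattern as csrfTokenRX in the text.
var csrfTokenRX = regexp.MustCompile(`<input type="hidden" name="csrf_token" value="(.+)">`)

// formTmpl is a hypothetical one-line template containing the hidden CSRF field.
var formTmpl = template.Must(template.New("form").Parse(
	`<input type="hidden" name="csrf_token" value="{{.}}">`))

// renderForm executes the template, letting html/template escape the token.
func renderForm(token string) string {
	var buf bytes.Buffer
	if err := formTmpl.Execute(&buf, token); err != nil {
		panic(err)
	}
	return buf.String()
}

// extractToken returns the captured (still escaped) value and its unescaped form.
func extractToken(body string) (raw, token string, ok bool) {
	matches := csrfTokenRX.FindStringSubmatch(body)
	if len(matches) < 2 {
		return "", "", false
	}
	return matches[1], html.UnescapeString(matches[1]), true
}

func main() {
	original := "M8Tq+P/9bw=="
	raw, token, _ := extractToken(renderForm(original))
	fmt.Println(raw)               // M8Tq&#43;P/9bw==
	fmt.Println(token == original) // true
}
```

Without the html.UnescapeString() call, the test would post the escaped value back to the server and the CSRF check would fail for any token that happens to contain a +.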
Now that’s in place, let’s go back to our cmd/web/handlers_test.go file and create a new TestUserSignup test.
To start with, we’ll make this perform a GET /user/signup request and then extract and print out the CSRF token from the HTML response body. Like so:
func TestUserSignup(t *testing.T) {
	app := newTestApplication(t)
	ts := newTestServer(t, app.routes())
	defer ts.Close()

	// Make a GET /user/signup request and then extract the CSRF token from the response body.
	_, _, body := ts.get(t, "/user/signup")
	csrfToken := extractCSRFToken(t, body)

	// Log the CSRF token value in our test output using the t.Logf() function. The t.Logf()
	// function works in the same way as fmt.Printf(), but writes the provided message to the test
	// output.
	t.Logf("CSRF token is: %q", csrfToken)
}
Importantly, you must run tests using the -v flag (to enable verbose output) in order to see t.Logf() output from tests that pass. Messages logged by a failing test are printed even without -v.
zzh@ZZHPC:/zdata/Github/snippetbox$ go test -v -run="TestUserSignup" ./cmd/web
=== RUN   TestUserSignup
    handlers_test.go:93: CSRF token is: "M8TqH18rZrbECVrindSLpn0dFLvan66Ikf2XVditLq9anJHjoKTDbyLc1TK6quh82vhIN/irsB8tXIfw/P+9bw=="
--- PASS: TestUserSignup (0.00s)
PASS
ok  	snippetbox/cmd/web	0.007s
zzh@ZZHPC:/zdata/Github/snippetbox$ go test -v -run="TestUserSignup" ./cmd/web
=== RUN   TestUserSignup
    handlers_test.go:93: CSRF token is: "M8TqH18rZrbECVrindSLpn0dFLvan66Ikf2XVditLq9anJHjoKTDbyLc1TK6quh82vhIN/irsB8tXIfw/P+9bw=="
--- PASS: TestUserSignup (0.00s)
PASS
ok  	snippetbox/cmd/web	(cached)
zzh@ZZHPC:/zdata/Github/snippetbox$ go test -count=1 -v -run="TestUserSignup" ./cmd/web
=== RUN   TestUserSignup
    handlers_test.go:93: CSRF token is: "dnLWaX8cNZ4WT4p8wLIcC8GGHgIBMFG9zPEEdIObkJSyDUBc+g4vITXjEjSh1Zv4HNFL6fAIk6AKNh9jSREijg=="
--- PASS: TestUserSignup (0.00s)
PASS
ok  	snippetbox/cmd/web	0.006s
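Note: If you run this test a second time immediately afterwards, without changing anything in the cmd/web package, you’ll get the same CSRF token in the test output because the test result has been cached (note the (cached) marker in the second run above). Passing the -count=1 flag, as in the third run, bypasses the cache and forces the test to run again, which produces a fresh token.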
Testing POST requests
Now let’s head back to our cmd/web/testhelpers_test.go file and create a new postForm() method on our testServer type, which we can use to send a POST request to our test server with specific form data in the request body.
// Create a postForm method for sending POST requests to the server. The final parameter to this
// method is a url.Values object which can contain any form data that you want to send in the
// request body.
func (ts *testServer) postForm(t *testing.T, urlPath string, form url.Values) (int, http.Header, string) {
	res, err := ts.Client().PostForm(ts.URL+urlPath, form)
	if err != nil {
		t.Fatal(err)
	}

	// Read the response body from the test server.
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	body = bytes.TrimSpace(body)

	// Return the response status, headers and body.
	return res.StatusCode, res.Header, string(body)
}
And now, at last, we’re ready to add some table-driven sub-tests to test the behavior of our application’s POST /user/signup route. Specifically, we want to test that:
- A valid signup results in a 303 See Other response.
- A form submission without a valid CSRF token results in a 400 Bad Request response.
- An invalid form submission results in a 422 Unprocessable Entity response and the signup form is redisplayed. This should happen when:
  - The name, email or password fields are empty.
  - The email is not in a valid format.
  - The password is less than 8 characters long.
  - The email address is already in use.
Go ahead and update the TestUserSignup function to carry out these tests like so:
func TestUserSignup(t *testing.T) {
	app := newTestApplication(t)
	ts := newTestServer(t, app.routes())
	defer ts.Close()

	_, _, body := ts.get(t, "/user/signup")
	validCSRFToken := extractCSRFToken(t, body)

	const (
		validName     = "Bob"
		validPassword = "validPa$$word"
		validEmail    = "bob@example.com"
		formTag       = `<form action="/user/signup" method="POST" novalidate>`
	)

	tests := []struct {
		name          string
		userName      string
		userEmail     string
		userPassword  string
		csrfToken     string
		expectCode    int
		expectFormTag string
	}{
		{
			name:         "Valid submission",
			userName:     validName,
			userEmail:    validEmail,
			userPassword: validPassword,
			csrfToken:    validCSRFToken,
			expectCode:   http.StatusSeeOther,
		},
		{
			name:         "Invalid CSRF Token",
			userName:     validName,
			userEmail:    validEmail,
			userPassword: validPassword,
			csrfToken:    "wrongToken",
			expectCode:   http.StatusBadRequest,
		},
		{
			name:          "Empty name",
			userName:      "",
			userEmail:     validEmail,
			userPassword:  validPassword,
			csrfToken:     validCSRFToken,
			expectCode:    http.StatusUnprocessableEntity,
			expectFormTag: formTag,
		},
		{
			name:          "Empty email",
			userName:      validName,
			userEmail:     "",
			userPassword:  validPassword,
			csrfToken:     validCSRFToken,
			expectCode:    http.StatusUnprocessableEntity,
			expectFormTag: formTag,
		},
		{
			name:          "Empty password",
			userName:      validName,
			userEmail:     validEmail,
			userPassword:  "",
			csrfToken:     validCSRFToken,
			expectCode:    http.StatusUnprocessableEntity,
			expectFormTag: formTag,
		},
		{
			name:          "Invalid email",
			userName:      validName,
			userEmail:     "bob@example.",
			userPassword:  validPassword,
			csrfToken:     validCSRFToken,
			expectCode:    http.StatusUnprocessableEntity,
			expectFormTag: formTag,
		},
		{
			name:          "Short password",
			userName:      validName,
			userEmail:     validEmail,
			userPassword:  "pa$$",
			csrfToken:     validCSRFToken,
			expectCode:    http.StatusUnprocessableEntity,
			expectFormTag: formTag,
		},
		{
			name:          "Duplicate email",
			userName:      validName,
			userEmail:     "dupe@example.com",
			userPassword:  validPassword,
			csrfToken:     validCSRFToken,
			expectCode:    http.StatusUnprocessableEntity,
			expectFormTag: formTag,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			form := url.Values{}
			form.Add("name", tc.userName)
			form.Add("email", tc.userEmail)
			form.Add("password", tc.userPassword)
			form.Add("csrf_token", tc.csrfToken)

			code, _, body := ts.postForm(t, "/user/signup", form)

			assert.Equal(t, code, tc.expectCode)

			if tc.expectFormTag != "" {
				assert.StringContains(t, body, tc.expectFormTag)
			}
		})
	}
}
zzh@ZZHPC:/zdata/Github/snippetbox$ go test -v ./cmd/web
=== RUN   TestPing
--- PASS: TestPing (0.00s)
=== RUN   TestSnippetView
=== RUN   TestSnippetView/Valid_ID
=== RUN   TestSnippetView/Non-existent_ID
=== RUN   TestSnippetView/Negative_ID
=== RUN   TestSnippetView/Decimal_ID
=== RUN   TestSnippetView/String_ID
=== RUN   TestSnippetView/Empty_ID
--- PASS: TestSnippetView (0.00s)
    --- PASS: TestSnippetView/Valid_ID (0.00s)
    --- PASS: TestSnippetView/Non-existent_ID (0.00s)
    --- PASS: TestSnippetView/Negative_ID (0.00s)
    --- PASS: TestSnippetView/Decimal_ID (0.00s)
    --- PASS: TestSnippetView/String_ID (0.00s)
    --- PASS: TestSnippetView/Empty_ID (0.00s)
=== RUN   TestUserSignup
=== RUN   TestUserSignup/Valid_submission
=== RUN   TestUserSignup/Invalid_CSRF_Token
=== RUN   TestUserSignup/Empty_name
=== RUN   TestUserSignup/Empty_email
=== RUN   TestUserSignup/Empty_password
=== RUN   TestUserSignup/Invalid_email
=== RUN   TestUserSignup/Short_password
=== RUN   TestUserSignup/Duplicate_email
--- PASS: TestUserSignup (0.01s)
    --- PASS: TestUserSignup/Valid_submission (0.00s)
    --- PASS: TestUserSignup/Invalid_CSRF_Token (0.00s)
    --- PASS: TestUserSignup/Empty_name (0.00s)
    --- PASS: TestUserSignup/Empty_email (0.00s)
    --- PASS: TestUserSignup/Empty_password (0.00s)
    --- PASS: TestUserSignup/Invalid_email (0.00s)
    --- PASS: TestUserSignup/Short_password (0.00s)
    --- PASS: TestUserSignup/Duplicate_email (0.00s)
=== RUN   TestCommonHeaders
--- PASS: TestCommonHeaders (0.00s)
=== RUN   TestHumanDate
=== RUN   TestHumanDate/UTC
=== RUN   TestHumanDate/Empty
=== RUN   TestHumanDate/CET
--- PASS: TestHumanDate (0.00s)
    --- PASS: TestHumanDate/UTC (0.00s)
    --- PASS: TestHumanDate/Empty (0.00s)
    --- PASS: TestHumanDate/CET (0.00s)
PASS
ok  	snippetbox/cmd/web	0.018s
Integration testing
Running end-to-end tests with mocked dependencies is a good thing to do, but we could improve confidence in our application even more if we also verify that our real MySQL database models are working as expected.
To do this we can run integration tests against a test version of our MySQL database, which mimics our production database but exists for testing purposes only.
As a demonstration, we’ll set up an integration test to ensure that our models.UserModel.Exists() method is working correctly.
Test database setup and teardown
The first step is to create the test version of our MySQL database.
If you’re following along, connect to MySQL from your terminal window as the root user and execute the following SQL statements to create a new test_snippetbox database and test_web user:
CREATE DATABASE test_snippetbox CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
CREATE USER test_web;
GRANT CREATE, DROP, ALTER, INDEX, SELECT, INSERT, UPDATE, DELETE ON test_snippetbox.* TO test_web;
ALTER USER test_web IDENTIFIED BY 'test';
Once that’s done, let’s make two SQL scripts:
1. A setup script to create the database tables (so that they mimic our production database) and insert a known set of test data that we can work with in our tests.
2. A teardown script which drops the database tables and data.
The idea is that we’ll call these scripts at the start and end of each integration test, so that the test database is fully reset each time. This helps ensure that any changes we make during one test are not ‘leaking’ and affecting the results of another test.
Let’s go ahead and create these scripts in a new internal/models/testdata directory like so:
File internal/models/testdata/setup.sql:
CREATE TABLE snippet (
    id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
    title VARCHAR(100) NOT NULL,
    content TEXT NOT NULL,
    created DATETIME NOT NULL,
    expires DATETIME NOT NULL
);

CREATE INDEX idx_snippet_created ON snippet(created);

CREATE TABLE user (
    id INTEGER NOT NULL PRIMARY KEY AUTO_INCREMENT,
    name VARCHAR(255) NOT NULL,
    email VARCHAR(255) NOT NULL,
    hashed_password CHAR(60) NOT NULL,
    created DATETIME NOT NULL
);

ALTER TABLE user ADD CONSTRAINT uc_user_email UNIQUE (email);

INSERT INTO user (name, email, hashed_password, created) VALUES (
    'Alice Jones',
    'alice@example.com',
    '$2a$12$NuTjWXm3KKntReFwyBVHyuf/to.HEwTy.eS206TNfkGfr6HzGJSWG',
    '2022-01-01 09:18:24'
);

COMMIT;
File internal/models/testdata/teardown.sql:
DROP TABLE user;
DROP TABLE snippet;
Note: The Go tool ignores any directories called testdata, so these scripts will be ignored when compiling your application. As an aside, it also ignores any directories or files whose names begin with an _ or . character.
Alright, now that we’ve got the scripts in place, let’s make a new file internal/models/testhelpers_test.go to hold some helper functions for our integration tests:
In this file let’s create a newTestDB() helper function which:
- Creates a new *sql.DB connection pool for the test database;
- Executes the setup.sql script to create the database tables and dummy data;
- Registers a ‘cleanup’ function which executes the teardown.sql script and closes the connection pool.
package models

import (
	"database/sql"
	"os"
	"testing"
)

func newTestDB(t *testing.T) *sql.DB {
	// Establish a sql.DB connection pool for our test database. Because our setup and teardown
	// scripts contain multiple SQL statements, we need to use the "multiStatements=true"
	// parameter in our DSN. This instructs our MySQL database driver to support executing
	// multiple SQL statements in one db.Exec() call.
	db, err := sql.Open("mysql", "test_web:test@tcp(localhost:3306)/test_snippetbox?parseTime=true&multiStatements=true")
	if err != nil {
		t.Fatal(err)
	}

	// Read the setup SQL script from the file and execute the statements, closing the connection
	// pool and calling t.Fatal() in the event of an error.
	script, err := os.ReadFile("./testdata/setup.sql")
	if err != nil {
		db.Close()
		t.Fatal(err)
	}
	_, err = db.Exec(string(script))
	if err != nil {
		db.Close()
		t.Fatal(err)
	}

	// Use t.Cleanup() to register a function *which will automatically be called by Go when the
	// current test (or sub-test) which calls newTestDB() has finished*. In this function we read
	// and execute the teardown script, and close the database connection pool.
	t.Cleanup(func() {
		defer db.Close()

		script, err := os.ReadFile("./testdata/teardown.sql")
		if err != nil {
			t.Fatal(err)
		}
		_, err = db.Exec(string(script))
		if err != nil {
			t.Fatal(err)
		}
	})

	// Return the database connection pool.
	return db
}
The important thing to take away here is this:
Whenever we call this newTestDB() function inside a test (or sub-test) it will run the setup script against the test database. And when the test or sub-test finishes, the cleanup function will automatically be executed and the teardown script will be run.
Testing the UserModel.Exists method
Now that the preparatory work is done, we’re ready to actually write our integration test for the models.UserModel.Exists() method.
We know that our setup.sql script creates a user table containing one record (which should have the user ID 1 and email address alice@example.com). So we want to test that:
- Calling models.UserModel.Exists(1) returns a true boolean value and a nil error value.
- Calling models.UserModel.Exists() with any other user ID returns a false boolean value and a nil error value.
Let’s first head to our internal/assert package and create a new NilError() assertion, which we will use to check that an error value is nil, like so:
func NilError(t *testing.T, actual error) {
	t.Helper()

	if actual != nil {
		t.Errorf("got: %v; expected: nil", actual)
	}
}
Let's create a file internal/models/user_test.go and add a TestUserModelExists test containing the following code:
package models

import (
	"snippetbox/internal/assert"
	"testing"
)

func TestUserModelExists(t *testing.T) {
	tests := []struct {
		name   string
		userID int
		expect bool
	}{
		{
			name:   "Valid ID",
			userID: 1,
			expect: true,
		},
		{
			name:   "Zero ID",
			userID: 0,
			expect: false,
		},
		{
			name:   "Non-existent ID",
			userID: 2,
			expect: false,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Call the newTestDB() helper function to get a connection pool to our test database.
			// Calling this here -- inside t.Run() -- means that fresh database tables and data
			// will be set up and torn down for each sub-test.
			db := newTestDB(t)

			// Create a new instance of the UserModel.
			m := UserModel{db}

			// Call the UserModel.Exists() method and check that the return value and error match
			// the expected values for the sub-test.
			exists, err := m.Exists(tc.userID)

			assert.Equal(t, exists, tc.expect)
			assert.NilError(t, err)
		})
	}
}
zzh@ZZHPC:/zdata/Github/snippetbox$ go test -v ./internal/models
=== RUN   TestUserModelExists
=== RUN   TestUserModelExists/Valid_ID
=== RUN   TestUserModelExists/Zero_ID
=== RUN   TestUserModelExists/Non-existent_ID
--- PASS: TestUserModelExists (0.39s)
    --- PASS: TestUserModelExists/Valid_ID (0.10s)
    --- PASS: TestUserModelExists/Zero_ID (0.10s)
    --- PASS: TestUserModelExists/Non-existent_ID (0.20s)
PASS
ok  	snippetbox/internal/models	0.392s
The last line in the test output here is worth a mention. The total runtime for this test (0.392 seconds in the run above) is much longer than for our previous tests, all of which took a few milliseconds to run. This big increase in runtime is primarily due to the database operations that the test performs: the setup and teardown scripts are executed for every sub-test.
While a fraction of a second is a totally acceptable time to wait for this test in isolation, if you’re running hundreds of different integration tests against your database you might end up routinely waiting minutes, rather than seconds, for your tests to finish.
Skipping long-running tests
When your tests take a long time, you might decide that you want to skip specific long-running tests under certain circumstances. For example, you might decide to only run your integration tests before committing a change, instead of more frequently during development.
A common and idiomatic way to skip long-running tests is to use the testing.Short() function to check for the presence of a -short flag in your go test command, and then call the t.Skip() method to skip the test if the flag is present.
Let’s quickly update TestUserModelExists to do this before running its actual tests, like so:
func TestUserModelExists(t *testing.T) {
	// Skip the test if the "-short" flag is provided when running the test.
	if testing.Short() {
		t.Skip("models: skipping integration test")
	}

	...
}
zzh@ZZHPC:/zdata/Github/snippetbox$ go test -v -short ./...
?   	snippetbox/internal/assert	[no test files]
?   	snippetbox/internal/models/mocks	[no test files]
?   	snippetbox/internal/validator	[no test files]
?   	snippetbox/ui	[no test files]
=== RUN   TestPing
--- PASS: TestPing (0.00s)
=== RUN   TestSnippetView
=== RUN   TestSnippetView/Valid_ID
=== RUN   TestSnippetView/Non-existent_ID
=== RUN   TestSnippetView/Negative_ID
=== RUN   TestSnippetView/Decimal_ID
=== RUN   TestSnippetView/String_ID
=== RUN   TestSnippetView/Empty_ID
--- PASS: TestSnippetView (0.00s)
    --- PASS: TestSnippetView/Valid_ID (0.00s)
    --- PASS: TestSnippetView/Non-existent_ID (0.00s)
    --- PASS: TestSnippetView/Negative_ID (0.00s)
    --- PASS: TestSnippetView/Decimal_ID (0.00s)
    --- PASS: TestSnippetView/String_ID (0.00s)
    --- PASS: TestSnippetView/Empty_ID (0.00s)
=== RUN   TestUserSignup
=== RUN   TestUserSignup/Valid_submission
=== RUN   TestUserSignup/Invalid_CSRF_Token
=== RUN   TestUserSignup/Empty_name
=== RUN   TestUserSignup/Empty_email
=== RUN   TestUserSignup/Empty_password
=== RUN   TestUserSignup/Invalid_email
=== RUN   TestUserSignup/Short_password
=== RUN   TestUserSignup/Duplicate_email
--- PASS: TestUserSignup (0.01s)
    --- PASS: TestUserSignup/Valid_submission (0.00s)
    --- PASS: TestUserSignup/Invalid_CSRF_Token (0.00s)
    --- PASS: TestUserSignup/Empty_name (0.00s)
    --- PASS: TestUserSignup/Empty_email (0.00s)
    --- PASS: TestUserSignup/Empty_password (0.00s)
    --- PASS: TestUserSignup/Invalid_email (0.00s)
    --- PASS: TestUserSignup/Short_password (0.00s)
    --- PASS: TestUserSignup/Duplicate_email (0.00s)
=== RUN   TestCommonHeaders
--- PASS: TestCommonHeaders (0.00s)
=== RUN   TestHumanDate
=== RUN   TestHumanDate/UTC
=== RUN   TestHumanDate/Empty
=== RUN   TestHumanDate/CET
--- PASS: TestHumanDate (0.00s)
    --- PASS: TestHumanDate/UTC (0.00s)
    --- PASS: TestHumanDate/Empty (0.00s)
    --- PASS: TestHumanDate/CET (0.00s)
PASS
ok  	snippetbox/cmd/web	0.017s
=== RUN   TestUserModelExists
    user_test.go:11: models: skipping integration test
--- SKIP: TestUserModelExists (0.00s)
PASS
ok  	snippetbox/internal/models	0.002s
Profiling test coverage
A great feature of the go test tool is the metrics and visualizations that it provides for test coverage.
Go ahead and try running the tests in our project using the -cover flag like so:
zzh@ZZHPC:/zdata/Github/snippetbox$ go test -cover ./...
?   	snippetbox/ui	[no test files]
	snippetbox/internal/models/mocks		coverage: 0.0% of statements
	snippetbox/internal/validator		coverage: 0.0% of statements
	snippetbox/internal/assert		coverage: 0.0% of statements
ok  	snippetbox/cmd/web	0.021s	coverage: 44.1% of statements
ok  	snippetbox/internal/models	0.215s	coverage: 6.5% of statements
From the results here we can see that 44.1% of the statements in our cmd/web package are executed during our tests, and for our internal/models package the figure is 6.5%.
We can get a more detailed breakdown of test coverage by method and function by using the -coverprofile flag like so:
zzh@ZZHPC:/zdata/Github/snippetbox$ go test -coverprofile=/tmp/cover_profile.out ./...
?   	snippetbox/ui	[no test files]
	snippetbox/internal/models/mocks		coverage: 0.0% of statements
	snippetbox/internal/assert		coverage: 0.0% of statements
	snippetbox/internal/validator		coverage: 0.0% of statements
ok  	snippetbox/cmd/web	0.021s	coverage: 44.1% of statements
ok  	snippetbox/internal/models	0.172s	coverage: 6.5% of statements
This will execute your tests as normal and, if all your tests pass, will then write a coverage profile to the location given by the -coverprofile flag (/tmp/cover_profile.out here).
You can then view the coverage profile by using the go tool cover command like so:
zzh@ZZHPC:/zdata/Github/snippetbox$ go tool cover -func=/tmp/cover_profile.out
snippetbox/cmd/web/handlers.go:12:	ping	100.0%
snippetbox/cmd/web/handlers.go:16:	home	0.0%
snippetbox/cmd/web/handlers.go:29:	snippetView	92.3%
snippetbox/cmd/web/handlers.go:59:	snippetCreate	0.0%
snippetbox/cmd/web/handlers.go:69:	snippetCreatePost	0.0%
snippetbox/cmd/web/handlers.go:108:	userSignup	100.0%
snippetbox/cmd/web/handlers.go:114:	userSignupPost	88.5%
snippetbox/cmd/web/handlers.go:166:	userLogin	0.0%
snippetbox/cmd/web/handlers.go:172:	userLoginPost	0.0%
snippetbox/cmd/web/handlers.go:224:	userLogoutPost	0.0%
snippetbox/cmd/web/helpers.go:14:	serverError	0.0%
snippetbox/cmd/web/helpers.go:24:	clientError	0.0%
snippetbox/cmd/web/helpers.go:28:	render	58.3%
snippetbox/cmd/web/helpers.go:53:	newTemplateData	100.0%
snippetbox/cmd/web/helpers.go:63:	decodePostForm	50.0%
snippetbox/cmd/web/helpers.go:90:	isAuthenticated	75.0%
snippetbox/cmd/web/main.go:29:	main	0.0%
snippetbox/cmd/web/main.go:89:	openDB	0.0%
snippetbox/cmd/web/middleware.go:11:	commonHeaders	100.0%
snippetbox/cmd/web/middleware.go:28:	noSurf	100.0%
snippetbox/cmd/web/middleware.go:39:	logRequest	100.0%
snippetbox/cmd/web/middleware.go:54:	recoverPanic	66.7%
snippetbox/cmd/web/middleware.go:74:	requireAuthentication	16.7%
snippetbox/cmd/web/middleware.go:91:	authenticate	38.5%
snippetbox/cmd/web/routes.go:10:	routes	100.0%
snippetbox/cmd/web/template.go:22:	humanDate	100.0%
snippetbox/cmd/web/template.go:35:	newTemplateCache	83.3%
snippetbox/internal/assert/assert.go:8:	Equal	66.7%
snippetbox/internal/assert/assert.go:16:	StringContains	66.7%
snippetbox/internal/assert/assert.go:24:	NilError	66.7%
snippetbox/internal/models/mocks/snippet.go:18:	Insert	0.0%
snippetbox/internal/models/mocks/snippet.go:22:	Get	100.0%
snippetbox/internal/models/mocks/snippet.go:31:	Latest	0.0%
snippetbox/internal/models/mocks/user.go:7:	Insert	100.0%
snippetbox/internal/models/mocks/user.go:16:	Authenticate	0.0%
snippetbox/internal/models/mocks/user.go:24:	Exists	0.0%
snippetbox/internal/models/snippet.go:32:	Insert	0.0%
snippetbox/internal/models/snippet.go:56:	Get	0.0%
snippetbox/internal/models/snippet.go:89:	Latest	0.0%
snippetbox/internal/models/user.go:34:	Insert	0.0%
snippetbox/internal/models/user.go:64:	Authenticate	0.0%
snippetbox/internal/models/user.go:92:	Exists	100.0%
snippetbox/internal/validator/validator.go:24:	Valid	100.0%
snippetbox/internal/validator/validator.go:29:	AddNonFieldError	0.0%
snippetbox/internal/validator/validator.go:35:	AddFieldError	100.0%
snippetbox/internal/validator/validator.go:49:	CheckField	100.0%
snippetbox/internal/validator/validator.go:56:	NotEmpty	100.0%
snippetbox/internal/validator/validator.go:61:	MinChars	100.0%
snippetbox/internal/validator/validator.go:66:	MaxChars	0.0%
snippetbox/internal/validator/validator.go:71:	PermittedValue	0.0%
snippetbox/internal/validator/validator.go:76:	Match	100.0%
total:	(statements)	39.1%
An alternative and more visual way to view the coverage profile is to use the -html flag instead of -func.
zzh@ZZHPC:/zdata/Github/snippetbox$ go tool cover -html=/tmp/cover_profile.out
This will open a browser window containing a navigable, highlighted representation of your code, with covered statements shown in green and uncovered statements in red.
You can take this a step further and use the -covermode=count option when running go test like so:
zzh@ZZHPC:/zdata/Github/snippetbox$ go test -covermode=count -coverprofile=/tmp/cover_profile.out ./...
zzh@ZZHPC:/zdata/Github/snippetbox$ go tool cover -html=/tmp/cover_profile.out
Instead of just highlighting the statements in green and red, using -covermode=count makes the coverage profile record the exact number of times that each statement is executed during the tests.
When viewed in the browser, statements which are executed more frequently are then shown in a more saturated shade of green.
Note: If you’re running some of your tests in parallel, you should use the -covermode=atomic flag (instead of -covermode=count) to ensure an accurate count.
Add a debug mode
If you’ve used web frameworks for other languages, like Django or Laravel, then you might be familiar with the idea of a ‘debug’ mode where detailed errors are displayed to the user in an HTTP response instead of a generic "Internal Server Error" message.
Your goal in this exercise is to set up a similar ‘debug mode’ for our application, which can be enabled by using the -debug flag like so:
$ go run ./cmd/web -debug
When running in debug mode, any detailed errors and stack traces should be displayed in the browser.
Step 1
Create a new command-line flag with the name debug and a default value of false. Then make the value from this command-line flag available to your handlers via the application struct.
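One possible shape for Step 1 is sketched below. It is an assumption rather than the project’s actual main.go: it parses flags with a dedicated flag.FlagSet (instead of the package-level flag.Bool) so that parsing can be exercised with arbitrary arguments, newApplication is a hypothetical helper, and the application struct shows only the debug and logger fields.

```go
package main

import (
	"flag"
	"fmt"
	"log/slog"
	"os"
)

// application shows only the fields relevant here; the real struct has more.
type application struct {
	debug  bool
	logger *slog.Logger
}

// newApplication parses the given arguments and copies the -debug value into the
// application struct, so that handlers and helpers can read it via app.debug.
func newApplication(args []string) (*application, error) {
	fs := flag.NewFlagSet("web", flag.ContinueOnError)
	debug := fs.Bool("debug", false, "Enable debug mode")
	if err := fs.Parse(args); err != nil {
		return nil, err
	}

	return &application{
		debug:  *debug,
		logger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
	}, nil
}

func main() {
	app, err := newApplication(os.Args[1:])
	if err != nil {
		os.Exit(2)
	}
	fmt.Println("debug mode:", app.debug)
}
```

In the real application you would simply declare the flag alongside the existing ones in main() and set debug: *debug when constructing the application struct.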
Step 2
Go to the cmd/web/helpers.go file and update the serverError() helper so that it renders a detailed error message and stack trace in an HTTP response if, and only if, the debug flag has been set. Otherwise, send a generic error message as normal. You can get the stack trace using the debug.Stack() function.
package main

import (
	"bytes"
	"errors"
	"fmt"
	"net/http"
	"runtime/debug"
	"time"

	"github.com/go-playground/form/v4"
	"github.com/justinas/nosurf"
)

// The other helpers in this file (not shown) use the remaining imports.

func (app *application) serverError(w http.ResponseWriter, r *http.Request, err error) {
	var (
		method = r.Method
		uri    = r.URL.RequestURI()
		trace  = string(debug.Stack())
	)

	app.logger.Error(err.Error(), "method", method, "uri", uri)

	// In debug mode, send the error message followed by the stack trace.
	if app.debug {
		body := fmt.Sprintf("%s\n\n%s", err, trace)
		http.Error(w, body, http.StatusInternalServerError)
		return
	}

	http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
Step 3
Try out the change. Run the application and force a runtime error by using a DSN without the parseTime=true parameter:
zzh@ZZHPC:/zdata/Github/snippetbox$ go run ./cmd/web -debug -dsn="zeb:zebpwd@tcp(localhost:3306)/snippetbox"
Visiting https://localhost:4000/ should now result in a response containing the detailed error message followed by the stack trace.
Running the application again without the -debug flag should result in a generic "Internal Server Error" message.
Redirect user appropriately after login
If an unauthenticated user tries to visit GET /account/view they will be redirected to the login page. Then after logging in successfully, they will be redirected to the GET /snippet/create form. This is awkward and confusing for the user, as they end up on a different page to where they originally wanted to go.
Your goal in this exercise is to update the application so that users are redirected to the page they were originally trying to visit after logging in.
Step 1
Update the requireAuthentication() middleware so that, before an unauthenticated user is redirected to the login page, the URL path that they are trying to visit is added to their session data.
func (app *application) requireAuthentication(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// If the user is not authenticated, redirect them to the login page and return from the
		// middleware chain so that no subsequent handlers in the chain are executed.
		if !app.isAuthenticated(r) {
			// Add the path that the user is trying to access to their session data.
			app.sessionManager.Put(r.Context(), "redirectPathAfterLogin", r.URL.Path)
			http.Redirect(w, r, "/user/login", http.StatusSeeOther)
			return
		}

		// Otherwise set the "Cache-Control: no-store" header so that pages which require
		// authentication are not stored in the user's browser cache (or other intermediary cache).
		w.Header().Add("Cache-Control", "no-store")

		next.ServeHTTP(w, r)
	})
}
Step 2
Update the userLoginPost handler to check the user’s session for a URL path after they successfully log in. If one exists, remove it from the session data and redirect the user to that URL path. Otherwise, default to redirecting the user to /snippet/create.
func (app *application) userLoginPost(w http.ResponseWriter, r *http.Request) {
	var form userLoginForm

	err := app.decodePostForm(r, &form)
	if err != nil {
		app.clientError(w, http.StatusBadRequest)
		return
	}

	form.CheckField(validator.NotEmpty(form.Email), "email", "This field cannot be empty.")
	form.CheckField(validator.Match(form.Email, validator.EmailRX), "email", "This field must be a valid email address.")
	form.CheckField(validator.NotEmpty(form.Password), "password", "This field cannot be blank.")

	if !form.Valid() {
		data := app.newTemplateData(r)
		data.Form = form
		app.render(w, r, http.StatusUnprocessableEntity, "login.html", data)
		return
	}

	id, err := app.user.Authenticate(form.Email, form.Password)
	if err != nil {
		if errors.Is(err, models.ErrInvalidCredentials) {
			form.AddNonFieldError("Email or password is incorrect!")

			data := app.newTemplateData(r)
			data.Form = form
			app.render(w, r, http.StatusUnprocessableEntity, "login.html", data)
		} else {
			app.serverError(w, r, err)
		}
		return
	}

	// Use the RenewToken() method on the current session to change the session ID. It's good
	// practice to generate a new session ID when the authentication state or privilege level
	// changes for the user (e.g. login and logout operations).
	err = app.sessionManager.RenewToken(r.Context())
	if err != nil {
		app.serverError(w, r, err)
		return
	}

	// Add the ID of the current user to the session, so that they are now 'logged in'.
	app.sessionManager.Put(r.Context(), "authenticatedUserID", id)

	// Use the PopString method to retrieve and remove a value from the session data in one step.
	// If no matching key exists this will return the empty string.
	path := app.sessionManager.PopString(r.Context(), "redirectPathAfterLogin")
	if path != "" {
		http.Redirect(w, r, path, http.StatusSeeOther)
		return
	}

	http.Redirect(w, r, "/snippet/create", http.StatusSeeOther)
}
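The PopString() call is what makes the redirect a one-off: once the stored path has been used, later logins in the same session fall back to /snippet/create. The sketch below imitates that behavior with a hypothetical in-memory memorySession type. It is not the scs session manager API, only a stand-in with Put and PopString methods of similar meaning, keyed by a made-up session ID instead of a request context.

```go
package main

import (
	"fmt"
	"sync"
)

// memorySession is a hypothetical stand-in for the session manager.
type memorySession struct {
	mu   sync.Mutex
	data map[string]map[string]string
}

func newMemorySession() *memorySession {
	return &memorySession{data: make(map[string]map[string]string)}
}

// Put stores a value for the given session.
func (s *memorySession) Put(sessionID, key, value string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.data[sessionID] == nil {
		s.data[sessionID] = make(map[string]string)
	}
	s.data[sessionID][key] = value
}

// PopString returns the stored value and deletes it, or returns "" if it is absent.
func (s *memorySession) PopString(sessionID, key string) string {
	s.mu.Lock()
	defer s.mu.Unlock()
	value := s.data[sessionID][key] // indexing a nil inner map yields ""
	delete(s.data[sessionID], key)  // deleting from a nil map is a no-op
	return value
}

// redirectAfterLogin mirrors the end of userLoginPost: use the stored path once if
// there is one, otherwise fall back to /snippet/create.
func redirectAfterLogin(s *memorySession, sessionID string) string {
	if path := s.PopString(sessionID, "redirectPathAfterLogin"); path != "" {
		return path
	}
	return "/snippet/create"
}

func main() {
	s := newMemorySession()
	// What requireAuthentication does before redirecting to /user/login.
	s.Put("sess-1", "redirectPathAfterLogin", "/account/view")

	fmt.Println(redirectAfterLogin(s, "sess-1")) // /account/view
	fmt.Println(redirectAfterLogin(s, "sess-1")) // /snippet/create
}
```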
Implement a ‘Change Password’ feature
Your goal in this exercise is to add the facility for an authenticated user to change their password, using a form that asks for their current password, a new password, and a confirmation of the new password.
During this exercise you should make sure to:
- Ask the user for their current password and verify that it matches the hashed password in the users table (to confirm it is actually them making the request).
- Hash their new password before updating the users table.
Step 1
Create two new routes and handlers:
- GET /account/password/update which maps to a new accountPasswordUpdate handler.
- POST /account/password/update which maps to a new accountPasswordUpdatePost handler.
Both routes should be restricted to authenticated users only.
func (app *application) accountPasswordUpdate(w http.ResponseWriter, r *http.Request) {
	// Some code will go here later...
}

func (app *application) accountPasswordUpdatePost(w http.ResponseWriter, r *http.Request) {
	// Some code will go here later...
}
package main

import (
	"net/http"

	"snippetbox/ui"

	"github.com/justinas/alice"
)

func (app *application) routes() http.Handler {
	mux := http.NewServeMux()

	mux.Handle("GET /static/", http.FileServerFS(ui.Files))

	mux.HandleFunc("GET /ping", ping)

	// Unprotected routes using the "dynamic" middleware chain.
	dynamic := alice.New(app.sessionManager.LoadAndSave, noSurf, app.authenticate)

	mux.Handle("GET /{$}", dynamic.ThenFunc(app.home))
	mux.Handle("GET /about", dynamic.ThenFunc(app.about))
	mux.Handle("GET /snippet/view/{id}", dynamic.ThenFunc(app.snippetView))
	mux.Handle("GET /user/signup", dynamic.ThenFunc(app.userSignup))
	mux.Handle("POST /user/signup", dynamic.ThenFunc(app.userSignupPost))
	mux.Handle("GET /user/login", dynamic.ThenFunc(app.userLogin))
	mux.Handle("POST /user/login", dynamic.ThenFunc(app.userLoginPost))

	// Protected (authenticated-only) routes using the "protected" middleware chain, which
	// includes the requireAuthentication middleware.
	protected := dynamic.Append(app.requireAuthentication)

	mux.Handle("GET /snippet/create", protected.ThenFunc(app.snippetCreate))
	mux.Handle("POST /snippet/create", protected.ThenFunc(app.snippetCreatePost))
	mux.Handle("POST /user/logout", protected.ThenFunc(app.userLogoutPost))
	mux.Handle("GET /account/view", protected.ThenFunc(app.accountView))
	mux.Handle("GET /account/password/update", protected.ThenFunc(app.accountPasswordUpdate))
	mux.Handle("POST /account/password/update", protected.ThenFunc(app.accountPasswordUpdatePost))

	standard := alice.New(app.recoverPanic, app.logRequest, commonHeaders)

	return standard.Then(mux)
}
Step 2
Create a new ui/html/pages/password.html file which contains the change password form. This form should:
- Have three fields: currentPassword , newPassword and newPasswordConfirmation .
- POST the form data to /account/password/update when submitted.
- Display errors for each of the fields in the event of a validation error.
- Not re-display passwords in the event of a validation error.
Then update the cmd/web/handlers.go file to include a new accountPasswordUpdateForm struct that you can parse the form data into, and update the accountPasswordUpdate handler to display this empty form.
{{define "title"}}Change Password{{end}}

{{define "main"}}
<h2>Change Password</h2>
<form action="/account/password/update" method="POST" novalidate>
    <input type="hidden" name="csrf_token" value="{{.CSRFToken}}">
    <div>
        <label>Current password:</label>
        {{with .Form.FieldErrors.currentPassword}}
            <label class="error">{{.}}</label>
        {{end}}
        <input type="password" name="currentPassword">
    </div>
    <div>
        <label>New password:</label>
        {{with .Form.FieldErrors.newPassword}}
            <label class="error">{{.}}</label>
        {{end}}
        <input type="password" name="newPassword">
    </div>
    <div>
        <label>Confirm new password:</label>
        {{with .Form.FieldErrors.newPasswordConfirmation}}
            <label class="error">{{.}}</label>
        {{end}}
        <input type="password" name="newPasswordConfirmation">
    </div>
    <div>
        <input type="submit" value="Change password">
    </div>
</form>
{{end}}
type accountPasswordUpdateForm struct {
	CurrentPassword         string `form:"currentPassword"`
	NewPassword             string `form:"newPassword"`
	NewPasswordConfirmation string `form:"newPasswordConfirmation"`
	validator.Validator     `form:"-"`
}

func (app *application) accountPasswordUpdate(w http.ResponseWriter, r *http.Request) {
	data := app.newTemplateData(r)
	data.Form = accountPasswordUpdateForm{}

	app.render(w, r, http.StatusOK, "password.html", data)
}
Step 3
Update the accountPasswordUpdatePost handler to carry out the following form validation checks, and re-display the form with the relevant error messages in the event of any failures.
- All three fields are required.
- The newPassword value must be at least 8 characters long.
- The newPassword and newPasswordConfirmation values must match.
func (app *application) accountPasswordUpdatePost(w http.ResponseWriter, r *http.Request) {
	var form accountPasswordUpdateForm

	err := app.decodePostForm(r, &form)
	if err != nil {
		app.clientError(w, http.StatusBadRequest)
		return
	}

	form.CheckField(validator.NotEmpty(form.CurrentPassword), "currentPassword", "This field cannot be empty.")
	form.CheckField(validator.NotEmpty(form.NewPassword), "newPassword", "This field cannot be empty.")
	form.CheckField(validator.MinChars(form.NewPassword, 8), "newPassword", "This field must be at least 8 characters long.")
	form.CheckField(validator.NotEmpty(form.NewPasswordConfirmation), "newPasswordConfirmation", "This field cannot be empty.")
	form.CheckField(form.NewPassword == form.NewPasswordConfirmation, "newPasswordConfirmation", "Passwords do not match.")

	if !form.Valid() {
		data := app.newTemplateData(r)
		data.Form = form
		app.render(w, r, http.StatusUnprocessableEntity, "password.html", data)
		return
	}

	// The actual password update is added in Step 5.
}
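The Step 3 checks rely on the project's internal/validator package, which isn't shown in this section. The sketch below reimplements NotEmpty, MinChars and a Validator as an assumption about how they behave. This Validator keeps only the first error recorded for each field. If the real one does the same (an assumption here), an empty newPassword reports "cannot be empty" rather than the length message.

```go
package main

import (
	"fmt"
	"strings"
	"unicode/utf8"
)

// Validator, NotEmpty and MinChars are hypothetical re-implementations of the helpers in
// internal/validator; the project's real versions may differ in detail.
type Validator struct {
	FieldErrors map[string]string
}

func (v *Validator) Valid() bool { return len(v.FieldErrors) == 0 }

// AddFieldError records a message for key unless one is already present (first error wins).
func (v *Validator) AddFieldError(key, message string) {
	if v.FieldErrors == nil {
		v.FieldErrors = make(map[string]string)
	}
	if _, exists := v.FieldErrors[key]; !exists {
		v.FieldErrors[key] = message
	}
}

func (v *Validator) CheckField(ok bool, key, message string) {
	if !ok {
		v.AddFieldError(key, message)
	}
}

func NotEmpty(value string) bool { return strings.TrimSpace(value) != "" }

// MinChars counts runes rather than bytes, so a multi-byte character counts once.
func MinChars(value string, n int) bool { return utf8.RuneCountInString(value) >= n }

type accountPasswordUpdateForm struct {
	CurrentPassword, NewPassword, NewPasswordConfirmation string
	Validator
}

// validate applies the same five checks as the Step 3 handler.
func (f *accountPasswordUpdateForm) validate() {
	f.CheckField(NotEmpty(f.CurrentPassword), "currentPassword", "This field cannot be empty.")
	f.CheckField(NotEmpty(f.NewPassword), "newPassword", "This field cannot be empty.")
	f.CheckField(MinChars(f.NewPassword, 8), "newPassword", "This field must be at least 8 characters long.")
	f.CheckField(NotEmpty(f.NewPasswordConfirmation), "newPasswordConfirmation", "This field cannot be empty.")
	f.CheckField(f.NewPassword == f.NewPasswordConfirmation, "newPasswordConfirmation", "Passwords do not match.")
}

func main() {
	forms := []accountPasswordUpdateForm{
		{},
		{CurrentPassword: "pa$$word", NewPassword: "short", NewPasswordConfirmation: "short"},
		{CurrentPassword: "pa$$word", NewPassword: "longenough1", NewPasswordConfirmation: "longenough2"},
		{CurrentPassword: "pa$$word", NewPassword: "longenough1", NewPasswordConfirmation: "longenough1"},
	}
	for _, f := range forms {
		f.validate()
		fmt.Println(f.Valid(), f.FieldErrors)
	}
}
```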
Step 4
In your internal/models/users.go file, create a new UserModel.UpdatePassword() method with the following signature:
func (m *UserModel) UpdatePassword(id int, currentPassword, newPassword string) error
In this method:
1. Retrieve the user details for the user with the ID given by the id parameter from the database.
2. Check that the currentPassword value matches the hashed password for the user. If it doesn’t match, return an ErrInvalidCredentials error.
3. Otherwise, hash the newPassword value and update the hashed_password column in the users table for the relevant user.
Also update the UserModelInterface interface type to include the UpdatePassword() method that you’ve just created.
type UserModelInterface interface {
	Insert(name, email, password string) error
	Authenticate(email, password string) (int, error)
	Exists(id int) (bool, error)
	Get(id int) (User, error)
	UpdatePassword(id int, currentPassword, newPassword string) error
}

...

func (m *UserModel) UpdatePassword(id int, currentPassword, newPassword string) error {
	// 1. Retrieve the current hashed password for the user.
	var currentHashedPassword []byte

	stmt := `SELECT hashed_password FROM user WHERE id = ?`

	err := m.DB.QueryRow(stmt, id).Scan(&currentHashedPassword)
	if err != nil {
		return err
	}

	// 2. Check the supplied current password against the stored hash.
	err = bcrypt.CompareHashAndPassword(currentHashedPassword, []byte(currentPassword))
	if err != nil {
		if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
			return ErrInvalidCredentials
		}
		return err
	}

	// 3. Hash the new password and store it.
	newHashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12)
	if err != nil {
		return err
	}

	stmt = `UPDATE user SET hashed_password = ? WHERE id = ?`

	_, err = m.DB.Exec(stmt, string(newHashedPassword), id)
	return err
}
Step 5
Update the accountPasswordUpdatePost handler so that if the form is valid, it calls the UserModel.UpdatePassword() method (remember, the user’s ID should be in the session data).
In the event of a models.ErrInvalidCredentials error, inform the user that they have entered the wrong value in the currentPassword form field. If the update succeeds, add a flash message to the user’s session saying that their password has been successfully changed and redirect them to their account page. Any other error should be handled as a server error.
func (app *application) accountPasswordUpdatePost(w http.ResponseWriter, r *http.Request) {
	var form accountPasswordUpdateForm

	err := app.decodePostForm(r, &form)
	if err != nil {
		app.clientError(w, http.StatusBadRequest)
		return
	}

	form.CheckField(validator.NotEmpty(form.CurrentPassword), "currentPassword", "This field cannot be empty.")
	form.CheckField(validator.NotEmpty(form.NewPassword), "newPassword", "This field cannot be empty.")
	form.CheckField(validator.MinChars(form.NewPassword, 8), "newPassword", "This field must be at least 8 characters long.")
	form.CheckField(validator.NotEmpty(form.NewPasswordConfirmation), "newPasswordConfirmation", "This field cannot be empty.")
	form.CheckField(form.NewPassword == form.NewPasswordConfirmation, "newPasswordConfirmation", "Passwords do not match.")

	if !form.Valid() {
		data := app.newTemplateData(r)
		data.Form = form
		app.render(w, r, http.StatusUnprocessableEntity, "password.html", data)
		return
	}

	// The ID of the logged-in user was stored in the session at login.
	userID := app.sessionManager.GetInt(r.Context(), "authenticatedUserID")

	err = app.user.UpdatePassword(userID, form.CurrentPassword, form.NewPassword)
	if err != nil {
		if errors.Is(err, models.ErrInvalidCredentials) {
			form.AddFieldError("currentPassword", "Current password is incorrect.")

			data := app.newTemplateData(r)
			data.Form = form
			app.render(w, r, http.StatusUnprocessableEntity, "password.html", data)
		} else {
			app.serverError(w, r, err)
		}
		return
	}

	app.sessionManager.Put(r.Context(), "flash", "Your password has been updated!")

	http.Redirect(w, r, "/account/view", http.StatusSeeOther)
}
Step 6
Update the account page to include a link to the change password form, similar to this:
{{define "title"}}Account{{end}}

{{define "main"}}
<h2>Your Account</h2>
{{with .User}}
<table>
    <tr>
        <th>Name</th>
        <td>{{.Name}}</td>
    </tr>
    <tr>
        <th>Email</th>
        <td>{{.Email}}</td>
    </tr>
    <tr>
        <th>Joined</th>
        <td>{{humanDate .Created}}</td>
    </tr>
    <tr>
        <th>Password</th>
        <td><a href="/account/password/update">Change password</a></td>
    </tr>
</table>
{{end}}
{{end}}