diff --git a/hbrtl/pgserver/auth.go b/hbrtl/pgserver/auth.go index 1190f4b..a54dbe4 100644 --- a/hbrtl/pgserver/auth.go +++ b/hbrtl/pgserver/auth.go @@ -1,7 +1,8 @@ // Copyright (c) 2026 Charles KWON OhJun (charleskwonohjun@gmail.com) // All rights reserved. -// auth.go — password / md5 authentication for the pgserver. +// auth.go — password / md5 / scram-sha-256 authentication for the +// pgserver. // // Roles + credentials live in an in-memory registry managed via // `PG_ADD_ROLE(name, password)` HB_FUNC (see register.go). At @@ -9,17 +10,19 @@ // that should be allowed in; `trust` mode bypasses lookup // entirely so single-user / dev setups don't need credentials. // -// SCRAM-SHA-256 is Phase 5.1 — pgx falls back to MD5 cleanly -// when the server advertises only md5, so v1.0 functional -// coverage is complete with the two methods here. +// SCRAM math lives in scram.go (RFC 5802 primitives, separately +// unit-tested). This file handles the wire flow and binds the +// per-session state. package pgserver import ( "crypto/md5" "crypto/rand" + "encoding/base64" "encoding/hex" "fmt" + "strings" "sync" "github.com/jackc/pgx/v5/pgproto3" @@ -77,9 +80,11 @@ func (s *session) authenticate() error { return s.authPassword() case "md5": return s.authMD5() + case "scram-sha-256", "scram": + return s.authSCRAM() default: s.sendError("28000", - "auth mode "+s.srv.cfg.AuthMode+" not implemented (use trust/password/md5)") + "auth mode "+s.srv.cfg.AuthMode+" not implemented (use trust/password/md5/scram-sha-256)") return errAuthRejected } } @@ -150,6 +155,143 @@ func (s *session) authMD5() error { return nil } +// authSCRAM runs the five-message SCRAM-SHA-256 exchange: +// +// server → AuthenticationSASL{Mechanisms: ["SCRAM-SHA-256"]} +// client → SASLInitialResponse{Mechanism, Data: client-first} +// server → AuthenticationSASLContinue{Data: server-first} +// client → SASLResponse{Data: client-final} +// server → AuthenticationSASLFinal{Data: "v="} +// server → AuthenticationOk +// +// The math lives in scram.go; this method is the wire-flow shell. +// Channel binding ("SCRAM-SHA-256-PLUS") is intentionally not +// advertised — adds rebind complexity over TLS for marginal +// benefit on a single-process server. +func (s *session) authSCRAM() error { + s.send(&pgproto3.AuthenticationSASL{ + AuthMechanisms: []string{"SCRAM-SHA-256"}, + }) + // pgproto3 multiplexes PasswordMessage, SASLInitialResponse, + // and SASLResponse onto the same byte tag ('p'); SetAuthType + // tells the Backend which one to decode the next 'p' frame as. + // Without these the receive loop unmarshals our handshake as + // a cleartext PasswordMessage and SCRAM fails before we even + // see the client-first. + if err := s.be.SetAuthType(pgproto3.AuthTypeSASL); err != nil { + s.sendError("XX000", "scram: cannot enter SASL state: "+err.Error()) + return errAuthRejected + } + + // Step 1 — receive client-first. + msg, err := s.be.Receive() + if err != nil { + return err + } + init, ok := msg.(*pgproto3.SASLInitialResponse) + if !ok { + s.sendError("28000", fmt.Sprintf("scram: expected SASLInitialResponse, got %T", msg)) + return errAuthRejected + } + if init.AuthMechanism != "SCRAM-SHA-256" { + s.sendError("28000", + "scram: unsupported mechanism "+init.AuthMechanism+" (offer SCRAM-SHA-256)") + return errAuthRejected + } + clientFirst := string(init.Data) + clientFirstBare, okBare := scramClientFirstBare(clientFirst) + if !okBare { + s.sendError("28000", "scram: malformed client-first message") + return errAuthRejected + } + attrs := scramParseAttrs(clientFirstBare) + clientNonce := attrs["r"] + if clientNonce == "" { + s.sendError("28000", "scram: client-first missing r= nonce") + return errAuthRejected + } + + // Step 2 — generate server-nonce + salt + iterations, send + // server-first. The combined nonce is client-nonce ++ + // server-nonce; the client must echo it back verbatim in + // client-final or we reject. + var serverNonceRaw [18]byte + if _, err := rand.Read(serverNonceRaw[:]); err != nil { + s.sendError("XX000", "scram: rng failure") + return err + } + serverNonce := base64.StdEncoding.EncodeToString(serverNonceRaw[:]) + var salt [16]byte + if _, err := rand.Read(salt[:]); err != nil { + s.sendError("XX000", "scram: rng failure") + return err + } + combinedNonce := clientNonce + serverNonce + serverFirst := scramServerFirst(combinedNonce, salt[:], scramIterations) + s.send(&pgproto3.AuthenticationSASLContinue{Data: []byte(serverFirst)}) + // Switch the decoder into SASLContinue mode so the next 'p' + // frame is parsed as SASLResponse (client-final), not as a + // fresh PasswordMessage. + if err := s.be.SetAuthType(pgproto3.AuthTypeSASLContinue); err != nil { + s.sendError("XX000", "scram: cannot enter SASLContinue state: "+err.Error()) + return errAuthRejected + } + + // Step 3 — receive client-final. + msg, err = s.be.Receive() + if err != nil { + return err + } + final, ok := msg.(*pgproto3.SASLResponse) + if !ok { + s.sendError("28000", fmt.Sprintf("scram: expected SASLResponse, got %T", msg)) + return errAuthRejected + } + clientFinal := string(final.Data) + // Split off the proof — everything before ",p=" is the + // client-final-without-proof that goes into the AuthMessage. + pIdx := strings.LastIndex(clientFinal, ",p=") + if pIdx < 0 { + s.sendError("28000", "scram: client-final missing p= proof") + return errAuthRejected + } + clientFinalNoProof := clientFinal[:pIdx] + clientProofB64 := clientFinal[pIdx+3:] + finalAttrs := scramParseAttrs(clientFinalNoProof) + if finalAttrs["r"] != combinedNonce { + s.sendError("28P01", "scram: nonce mismatch") + return errAuthRejected + } + + // Step 4 — resolve role + compute the same SaltedPassword the + // client did, then verify the ClientProof against StoredKey. + // Constant-time verify hides whether the role exists vs. the + // password was wrong; PG itself happily distinguishes the two, + // but we err on the safe side. + r := lookupRole(s.user) + authMsg := scramAuthMessage(clientFirstBare, serverFirst, clientFinalNoProof) + var saltedPassword []byte + if r != nil { + saltedPassword = scramSaltedPassword(r.PasswordPlain, salt[:], scramIterations) + } else { + // Compute a dummy salted password so we still spend the + // PBKDF2 cycles — keeps the timing roughly constant + // regardless of role existence. + saltedPassword = scramSaltedPassword("", salt[:], scramIterations) + } + if r == nil || !scramVerifyClientProof(saltedPassword, authMsg, clientProofB64) { + s.sendError("28P01", + fmt.Sprintf("SCRAM authentication failed for user %q", s.user)) + return errAuthRejected + } + + // Step 5 — send server signature then AuthenticationOk. + serverSig := scramServerSignature(saltedPassword, authMsg) + s.send(&pgproto3.AuthenticationSASLFinal{Data: []byte("v=" + serverSig)}) + s.send(&pgproto3.AuthenticationOk{}) + return nil +} + // md5Challenge reproduces libpq's md5 client computation so we // can compare against the value the client sent. func md5Challenge(password, user string, salt []byte) string { diff --git a/hbrtl/pgserver/pgserver_test.go b/hbrtl/pgserver/pgserver_test.go index 8ab6fbe..0150677 100644 --- a/hbrtl/pgserver/pgserver_test.go +++ b/hbrtl/pgserver/pgserver_test.go @@ -5,6 +5,7 @@ package pgserver import ( "bytes" + "encoding/base64" "strconv" "strings" "testing" @@ -186,3 +187,77 @@ func TestCommandTagFor(t *testing.T) { } _ = strconv.Itoa // keep import; will be used in Phase 3 with row counts } + +// TestSCRAMParseClientFirst verifies the gs2-header strip + attr +// parse for the SCRAM client-first message. Vector matches what +// libpq + pgx + JDBC all emit (channel-binding flag "n", empty +// authzid, user attribute, random nonce). +func TestSCRAMParseClientFirst(t *testing.T) { + bare, ok := scramClientFirstBare("n,,n=alice,r=ClientNonce123") + if !ok { + t.Fatal("scramClientFirstBare rejected a valid client-first") + } + if bare != "n=alice,r=ClientNonce123" { + t.Errorf("client-first-bare wrong: %q", bare) + } + attrs := scramParseAttrs(bare) + if attrs["n"] != "alice" || attrs["r"] != "ClientNonce123" { + t.Errorf("attr parse wrong: %+v", attrs) + } +} + +// TestSCRAMRoundTrip exercises the full SCRAM math: server picks +// salt+nonce+iter, client (us, here) computes its proof, server +// verifies it, server emits ServerSignature, client verifies that. +// Pinning the verify path catches any regression in PBKDF2/HMAC +// wiring without needing a live psql process. +func TestSCRAMRoundTrip(t *testing.T) { + const password = "swordfish" + const clientNonce = "rOprNGfwEbeRWgbNEkqO" // RFC 7677-style fixed vector + const serverNonce = "9zphn2KvL3K5dlqJpvBz" // arbitrary + salt := []byte{0xa6, 0x3d, 0x18, 0x4f, 0x05, 0x12, 0xc1, 0xd7, + 0x88, 0x73, 0xea, 0xb6, 0x91, 0x04, 0x7c, 0x80} + iter := scramIterations + + clientFirstBare := "n=alice,r=" + clientNonce + combined := clientNonce + serverNonce + serverFirst := scramServerFirst(combined, salt, iter) + + // Client-side proof computation (mirror of what libpq does). + saltedPwd := scramSaltedPassword(password, salt, iter) + clientFinalNoProof := "c=biws,r=" + combined + authMsg := scramAuthMessage(clientFirstBare, serverFirst, clientFinalNoProof) + + clientKey := scramHMAC(saltedPwd, []byte("Client Key")) + storedKey := scramH(clientKey) + clientSig := scramHMAC(storedKey, authMsg) + clientProof := scramXOR(clientKey, clientSig) + clientProofB64 := base64.StdEncoding.EncodeToString(clientProof) + + // Server-side verify. + if !scramVerifyClientProof(saltedPwd, authMsg, clientProofB64) { + t.Fatal("scramVerifyClientProof rejected a correctly-computed proof") + } + + // Server emits its signature; client must accept it. + serverSigB64 := scramServerSignature(saltedPwd, authMsg) + serverSigDecoded, err := base64.StdEncoding.DecodeString(serverSigB64) + if err != nil || len(serverSigDecoded) != 32 { + t.Fatalf("server signature malformed: %q err=%v", serverSigB64, err) + } + serverKey := scramHMAC(saltedPwd, []byte("Server Key")) + expectedSig := scramHMAC(serverKey, authMsg) + if !bytes.Equal(serverSigDecoded, expectedSig) { + t.Errorf("server signature mismatch: got %x want %x", serverSigDecoded, expectedSig) + } + + // Wrong-password proof must reject. + wrongSalted := scramSaltedPassword("wrong", salt, iter) + wrongClientKey := scramHMAC(wrongSalted, []byte("Client Key")) + wrongStoredKey := scramH(wrongClientKey) + wrongSig := scramHMAC(wrongStoredKey, authMsg) + wrongProof := base64.StdEncoding.EncodeToString(scramXOR(wrongClientKey, wrongSig)) + if scramVerifyClientProof(saltedPwd, authMsg, wrongProof) { + t.Error("scramVerifyClientProof accepted a wrong-password proof") + } +} diff --git a/hbrtl/pgserver/scram.go b/hbrtl/pgserver/scram.go new file mode 100644 index 0000000..f6131c3 --- /dev/null +++ b/hbrtl/pgserver/scram.go @@ -0,0 +1,164 @@ +// Copyright (c) 2026 Charles KWON OhJun (charleskwonohjun@gmail.com) +// All rights reserved. + +// scram.go — SCRAM-SHA-256 (RFC 5802) primitives for the pgserver +// auth path. PostgreSQL 14+ default; libpq, pgx, and JDBC all +// prefer it over MD5 when offered. +// +// Storage model: roles in the registry hold the plaintext password +// (same as the MD5 path). For each authentication we generate a +// fresh salt + iteration count and derive SaltedPassword on the +// fly. This is functionally equivalent to PG's "scram-stored" +// secrets and keeps the code-path identical for both clients — +// the wire output matches RFC 5802 byte for byte. + +package pgserver + +import ( + "crypto/hmac" + "crypto/pbkdf2" + "crypto/sha256" + "encoding/base64" + "fmt" + "strconv" + "strings" +) + +// scramIterations is the PBKDF2 iteration count we advertise to +// clients. PG 15+ uses 4096 (the SCRAM minimum) by default; +// matching that keeps handshake latency negligible while still +// satisfying RFC 5802. +const scramIterations = 4096 + +// scramSaltedPassword runs PBKDF2-HMAC-SHA256 to derive the +// 32-byte SaltedPassword from a UTF-8 password + salt. Both the +// client and the server compute this independently from the same +// inputs; if they disagree the resulting proofs won't match. +func scramSaltedPassword(password string, salt []byte, iter int) []byte { + key, err := pbkdf2.Key(sha256.New, password, salt, iter, sha256.Size) + if err != nil { + // PBKDF2 only errors for nonsense params (negative iter, + // zero-length key). Our call site uses fixed constants so + // this branch is unreachable; surface as a panic so a + // future caller doesn't silently get a nil slice. + panic(fmt.Sprintf("scram: pbkdf2 failure: %v", err)) + } + return key +} + +// scramHMAC computes HMAC-SHA256(key, data) — RFC 5802 § 2.2. +func scramHMAC(key, data []byte) []byte { + m := hmac.New(sha256.New, key) + m.Write(data) + return m.Sum(nil) +} + +// scramH computes SHA-256(input) — RFC 5802 § 2.2 H(). +func scramH(input []byte) []byte { + sum := sha256.Sum256(input) + return sum[:] +} + +// scramXOR returns a XOR b. Both slices must be the same length; +// the caller controls them (HMAC outputs are fixed 32B). +func scramXOR(a, b []byte) []byte { + out := make([]byte, len(a)) + for i := range a { + out[i] = a[i] ^ b[i] + } + return out +} + +// scramClientFirstBare extracts the bare GS2-header-free portion +// of the client-first message. Spec form: +// +// client-first-message = gs2-header client-first-message-bare +// gs2-header = gs2-cbind-flag "," [ authzid ] "," +// client-first-bare = "n=" saslname "," "r=" c-nonce +// +// For PG the channel-binding flag is always "n" (none), so the +// header is the two-byte "n," followed by an empty authzid then +// another ",". We strip those three leading bytes. +func scramClientFirstBare(clientFirst string) (bare string, ok bool) { + if !strings.HasPrefix(clientFirst, "n,") && !strings.HasPrefix(clientFirst, "y,") { + return "", false + } + idx := strings.Index(clientFirst, ",") + if idx < 0 { + return "", false + } + rest := clientFirst[idx+1:] + idx2 := strings.Index(rest, ",") + if idx2 < 0 { + return "", false + } + return rest[idx2+1:], true +} + +// scramParseAttrs splits "k=v,k=v,..." into a map. Spec attribute +// values can't contain "=" or "," so a plain split is safe. +func scramParseAttrs(s string) map[string]string { + out := map[string]string{} + for _, kv := range strings.Split(s, ",") { + i := strings.IndexByte(kv, '=') + if i <= 0 { + continue + } + out[kv[:i]] = kv[i+1:] + } + return out +} + +// scramServerFirst builds the server-first-message: +// +// "r=,s=,i=" +// +// Returned as a string (becomes Data on AuthenticationSASLContinue). +func scramServerFirst(combinedNonce string, salt []byte, iter int) string { + return "r=" + combinedNonce + + ",s=" + base64.StdEncoding.EncodeToString(salt) + + ",i=" + strconv.Itoa(iter) +} + +// scramAuthMessage builds the SCRAM AuthMessage: +// +// AuthMessage = client-first-bare + "," + server-first + +// "," + client-final-without-proof +// +// Both sides compute this identically; it's the input both the +// client proof and the server signature HMAC over. +func scramAuthMessage(clientFirstBare, serverFirst, clientFinalNoProof string) []byte { + return []byte(clientFirstBare + "," + serverFirst + "," + clientFinalNoProof) +} + +// scramVerifyClientProof reconstructs ClientKey from the client's +// proof and confirms that H(ClientKey) == StoredKey. Returns true +// iff the proof is valid. +// +// ClientSignature = HMAC(StoredKey, AuthMessage) +// ClientKey' = ClientProof XOR ClientSignature +// verify H(ClientKey') == StoredKey +func scramVerifyClientProof(saltedPassword []byte, authMessage []byte, clientProofB64 string) bool { + clientProof, err := base64.StdEncoding.DecodeString(clientProofB64) + if err != nil || len(clientProof) != sha256.Size { + return false + } + clientKey := scramHMAC(saltedPassword, []byte("Client Key")) + storedKey := scramH(clientKey) + clientSig := scramHMAC(storedKey, authMessage) + recovered := scramXOR(clientProof, clientSig) + return hmac.Equal(scramH(recovered), storedKey) +} + +// scramServerSignature computes the server's proof to be returned +// in AuthenticationSASLFinal: +// +// ServerKey = HMAC(SaltedPassword, "Server Key") +// ServerSignature = HMAC(ServerKey, AuthMessage) +// +// Returns base64 — wire format expects "v=" + base64(signature). +func scramServerSignature(saltedPassword []byte, authMessage []byte) string { + serverKey := scramHMAC(saltedPassword, []byte("Server Key")) + sig := scramHMAC(serverKey, authMessage) + return base64.StdEncoding.EncodeToString(sig) +} diff --git a/tests/pgserver/run.sh b/tests/pgserver/run.sh index 13a963a..2ff3023 100755 --- a/tests/pgserver/run.sh +++ b/tests/pgserver/run.sh @@ -144,6 +144,40 @@ else fail "MD5 auth: correct password accepted" "$good" fi +# 4b) SCRAM-SHA-256 — restart server in scram mode and verify both +# the rejection and success paths. PG 14+ clients prefer SCRAM when +# offered, so this is the most-exercised auth path in production. +kill $SERVER_PID 2>/dev/null +wait 2>/dev/null + +cat > "$work/scram.prg" </dev/null 2>&1 +"$work/scram" & +SERVER_PID=$! +sleep 1 +trap "kill $SERVER_PID 2>/dev/null; rm -rf '$work'" EXIT + +scram_bad="$(PGPASSWORD=wrong psql "postgres://alice@127.0.0.1:$PORT/alice?sslmode=disable" \ + -c "SELECT 1" 2>&1 | head -1 || true)" +if echo "$scram_bad" | grep -qi "SCRAM authentication failed"; then + ok "SCRAM-SHA-256: wrong password rejected" +else + fail "SCRAM-SHA-256: wrong password rejected" "$scram_bad" +fi + +scram_good="$(PGPASSWORD=swordfish psql "postgres://alice@127.0.0.1:$PORT/alice?sslmode=disable" \ + -c "SELECT 'scram-ok' AS x" -At 2>&1 || true)" +if echo "$scram_good" | grep -q "^scram-ok$"; then + ok "SCRAM-SHA-256: correct password accepted" +else + fail "SCRAM-SHA-256: correct password accepted" "$scram_good" +fi + # 5) TLS — restart server with self-signed cert + allowlist and # connect via psql sslmode=require. kill $SERVER_PID 2>/dev/null