PG14+ clients (libpq, pgx, JDBC) prefer SCRAM over MD5 when offered;
this lands the five-message exchange (SASL / SASLInitialResponse /
SASLContinue / SASLResponse / SASLFinal) so they get their preferred
path. MD5 stays as the universal fallback.
Storage stays plaintext in the in-memory role registry. For each
authentication attempt we generate a fresh salt, pair it with the
configured iteration count, and derive SaltedPassword on the fly.
This gives the same net security as the existing MD5 path while
matching the RFC 5802 wire format byte for byte.
Critical detail: pgproto3's Backend multiplexes PasswordMessage,
SASLInitialResponse, and SASLResponse onto the same 'p' byte tag.
Without SetAuthType() the decoder picks PasswordMessage and the
handshake fails immediately. Switch state to AuthTypeSASL before
the client-first receive and AuthTypeSASLContinue before the
client-final receive.
Verified:
* SCRAM math (PBKDF2 / HMAC / proof verify / server signature)
via pinned unit test
* Live psql round-trip — correct password accepted, wrong password
rejected with proper SQLSTATE 28P01
* All 6 mandatory gates green (go test, SQL 43/43, compat 56/56,
std.ch 17/17, FRB 7/7, pgserver 11/11)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
311 lines
10 KiB
Go
311 lines
10 KiB
Go
// Copyright (c) 2026 Charles KWON OhJun (charleskwonohjun@gmail.com)
|
|
// All rights reserved.
|
|
|
|
// auth.go — password / md5 / scram-sha-256 authentication for the
|
|
// pgserver.
|
|
//
|
|
// Roles + credentials live in an in-memory registry managed via
|
|
// `PG_ADD_ROLE(name, password)` HB_FUNC (see register.go). At
|
|
// startup the bootstrap PRG calls PG_ADD_ROLE for every account
|
|
// that should be allowed in; `trust` mode bypasses lookup
|
|
// entirely so single-user / dev setups don't need credentials.
|
|
//
|
|
// SCRAM math lives in scram.go (RFC 5802 primitives, separately
|
|
// unit-tested). This file handles the wire flow and binds the
|
|
// per-session state.
|
|
|
|
package pgserver
|
|
|
|
import (
	"crypto/md5"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"strings"
	"sync"

	"github.com/jackc/pgx/v5/pgproto3"
)
|
|
|
|
// role is one stored credential in the in-memory registry.
//
// The password is deliberately kept in cleartext: cleartext-password
// mode compares against it directly, and MD5 mode recomputes the
// canonical challenge hash from it on every handshake, so no
// separate verifier table is required. This mirrors what Postgres
// does internally when md5 is configured against a scram-stored
// password before users opt in to SCRAM.
type role struct {
	// Name is the login name; it is also the registry key.
	Name string
	// PasswordPlain is the password exactly as it was registered.
	PasswordPlain string
}
|
|
|
|
var (
|
|
roleMu sync.RWMutex
|
|
roleMap = map[string]*role{}
|
|
)
|
|
|
|
// AddRole registers a user + password. Replaces any prior entry
|
|
// with the same name (so a bootstrap PRG can re-add roles on
|
|
// restart without first DROPping them).
|
|
func AddRole(name, password string) {
|
|
roleMu.Lock()
|
|
defer roleMu.Unlock()
|
|
roleMap[name] = &role{Name: name, PasswordPlain: password}
|
|
}
|
|
|
|
// RemoveRole drops a registered user. No-op if unknown.
|
|
func RemoveRole(name string) {
|
|
roleMu.Lock()
|
|
defer roleMu.Unlock()
|
|
delete(roleMap, name)
|
|
}
|
|
|
|
// lookupRole resolves a role by name. Returns nil if absent.
|
|
func lookupRole(name string) *role {
|
|
roleMu.RLock()
|
|
defer roleMu.RUnlock()
|
|
return roleMap[name]
|
|
}
|
|
|
|
// authenticate runs the auth handshake based on the server's
|
|
// configured AuthMode. The client identity (s.user) has already
|
|
// been recorded from the StartupMessage; we look it up in the
|
|
// role registry and execute the appropriate challenge.
|
|
func (s *session) authenticate() error {
|
|
switch s.srv.cfg.AuthMode {
|
|
case "", "trust":
|
|
s.send(&pgproto3.AuthenticationOk{})
|
|
return nil
|
|
case "password":
|
|
return s.authPassword()
|
|
case "md5":
|
|
return s.authMD5()
|
|
case "scram-sha-256", "scram":
|
|
return s.authSCRAM()
|
|
default:
|
|
s.sendError("28000",
|
|
"auth mode "+s.srv.cfg.AuthMode+" not implemented (use trust/password/md5/scram-sha-256)")
|
|
return errAuthRejected
|
|
}
|
|
}
|
|
|
|
// authPassword does the cleartext-password exchange. The wire
|
|
// payload is plaintext, so this is intended for TLS-protected
|
|
// links only — emit a warning if the connection isn't tls.Server
|
|
// (deferred until Phase 6 wires up TLS detection on session).
|
|
func (s *session) authPassword() error {
|
|
s.send(&pgproto3.AuthenticationCleartextPassword{})
|
|
msg, err := s.be.Receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
pwd, ok := msg.(*pgproto3.PasswordMessage)
|
|
if !ok {
|
|
s.sendError("28000", "expected PasswordMessage")
|
|
return errAuthRejected
|
|
}
|
|
r := lookupRole(s.user)
|
|
if r == nil || r.PasswordPlain != pwd.Password {
|
|
s.sendError("28P01",
|
|
fmt.Sprintf("password authentication failed for user %q", s.user))
|
|
return errAuthRejected
|
|
}
|
|
s.send(&pgproto3.AuthenticationOk{})
|
|
return nil
|
|
}
|
|
|
|
// authMD5 implements the libpq MD5 password challenge:
|
|
//
|
|
// server sends: AuthenticationMD5Password{Salt: 4 random bytes}
|
|
// client returns: "md5" || md5_hex( md5_hex(password || user) || salt )
|
|
// server verifies by recomputing with the stored plaintext
|
|
//
|
|
// MD5 is no longer recommended (libpq 14+ default is SCRAM), but
|
|
// every PG client implements it as a fallback. Adequate for v1.0
|
|
// over loopback or a trusted network; Phase 5.1 lands SCRAM.
|
|
func (s *session) authMD5() error {
|
|
var salt [4]byte
|
|
if _, err := rand.Read(salt[:]); err != nil {
|
|
s.sendError("XX000", "auth: rng failure")
|
|
return err
|
|
}
|
|
s.send(&pgproto3.AuthenticationMD5Password{Salt: salt})
|
|
msg, err := s.be.Receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
pwd, ok := msg.(*pgproto3.PasswordMessage)
|
|
if !ok {
|
|
s.sendError("28000", "expected PasswordMessage")
|
|
return errAuthRejected
|
|
}
|
|
r := lookupRole(s.user)
|
|
if r == nil {
|
|
s.sendError("28P01",
|
|
fmt.Sprintf("md5 authentication failed for user %q", s.user))
|
|
return errAuthRejected
|
|
}
|
|
expected := md5Challenge(r.PasswordPlain, s.user, salt[:])
|
|
if pwd.Password != expected {
|
|
s.sendError("28P01",
|
|
fmt.Sprintf("md5 authentication failed for user %q", s.user))
|
|
return errAuthRejected
|
|
}
|
|
s.send(&pgproto3.AuthenticationOk{})
|
|
return nil
|
|
}
|
|
|
|
// authSCRAM runs the five-message SCRAM-SHA-256 exchange:
|
|
//
|
|
// server → AuthenticationSASL{Mechanisms: ["SCRAM-SHA-256"]}
|
|
// client → SASLInitialResponse{Mechanism, Data: client-first}
|
|
// server → AuthenticationSASLContinue{Data: server-first}
|
|
// client → SASLResponse{Data: client-final}
|
|
// server → AuthenticationSASLFinal{Data: "v=<server-sig>"}
|
|
// server → AuthenticationOk
|
|
//
|
|
// The math lives in scram.go; this method is the wire-flow shell.
|
|
// Channel binding ("SCRAM-SHA-256-PLUS") is intentionally not
|
|
// advertised — adds rebind complexity over TLS for marginal
|
|
// benefit on a single-process server.
|
|
func (s *session) authSCRAM() error {
|
|
s.send(&pgproto3.AuthenticationSASL{
|
|
AuthMechanisms: []string{"SCRAM-SHA-256"},
|
|
})
|
|
// pgproto3 multiplexes PasswordMessage, SASLInitialResponse,
|
|
// and SASLResponse onto the same byte tag ('p'); SetAuthType
|
|
// tells the Backend which one to decode the next 'p' frame as.
|
|
// Without these the receive loop unmarshals our handshake as
|
|
// a cleartext PasswordMessage and SCRAM fails before we even
|
|
// see the client-first.
|
|
if err := s.be.SetAuthType(pgproto3.AuthTypeSASL); err != nil {
|
|
s.sendError("XX000", "scram: cannot enter SASL state: "+err.Error())
|
|
return errAuthRejected
|
|
}
|
|
|
|
// Step 1 — receive client-first.
|
|
msg, err := s.be.Receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
init, ok := msg.(*pgproto3.SASLInitialResponse)
|
|
if !ok {
|
|
s.sendError("28000", fmt.Sprintf("scram: expected SASLInitialResponse, got %T", msg))
|
|
return errAuthRejected
|
|
}
|
|
if init.AuthMechanism != "SCRAM-SHA-256" {
|
|
s.sendError("28000",
|
|
"scram: unsupported mechanism "+init.AuthMechanism+" (offer SCRAM-SHA-256)")
|
|
return errAuthRejected
|
|
}
|
|
clientFirst := string(init.Data)
|
|
clientFirstBare, okBare := scramClientFirstBare(clientFirst)
|
|
if !okBare {
|
|
s.sendError("28000", "scram: malformed client-first message")
|
|
return errAuthRejected
|
|
}
|
|
attrs := scramParseAttrs(clientFirstBare)
|
|
clientNonce := attrs["r"]
|
|
if clientNonce == "" {
|
|
s.sendError("28000", "scram: client-first missing r= nonce")
|
|
return errAuthRejected
|
|
}
|
|
|
|
// Step 2 — generate server-nonce + salt + iterations, send
|
|
// server-first. The combined nonce is client-nonce ++
|
|
// server-nonce; the client must echo it back verbatim in
|
|
// client-final or we reject.
|
|
var serverNonceRaw [18]byte
|
|
if _, err := rand.Read(serverNonceRaw[:]); err != nil {
|
|
s.sendError("XX000", "scram: rng failure")
|
|
return err
|
|
}
|
|
serverNonce := base64.StdEncoding.EncodeToString(serverNonceRaw[:])
|
|
var salt [16]byte
|
|
if _, err := rand.Read(salt[:]); err != nil {
|
|
s.sendError("XX000", "scram: rng failure")
|
|
return err
|
|
}
|
|
combinedNonce := clientNonce + serverNonce
|
|
serverFirst := scramServerFirst(combinedNonce, salt[:], scramIterations)
|
|
s.send(&pgproto3.AuthenticationSASLContinue{Data: []byte(serverFirst)})
|
|
// Switch the decoder into SASLContinue mode so the next 'p'
|
|
// frame is parsed as SASLResponse (client-final), not as a
|
|
// fresh PasswordMessage.
|
|
if err := s.be.SetAuthType(pgproto3.AuthTypeSASLContinue); err != nil {
|
|
s.sendError("XX000", "scram: cannot enter SASLContinue state: "+err.Error())
|
|
return errAuthRejected
|
|
}
|
|
|
|
// Step 3 — receive client-final.
|
|
msg, err = s.be.Receive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
final, ok := msg.(*pgproto3.SASLResponse)
|
|
if !ok {
|
|
s.sendError("28000", fmt.Sprintf("scram: expected SASLResponse, got %T", msg))
|
|
return errAuthRejected
|
|
}
|
|
clientFinal := string(final.Data)
|
|
// Split off the proof — everything before ",p=" is the
|
|
// client-final-without-proof that goes into the AuthMessage.
|
|
pIdx := strings.LastIndex(clientFinal, ",p=")
|
|
if pIdx < 0 {
|
|
s.sendError("28000", "scram: client-final missing p= proof")
|
|
return errAuthRejected
|
|
}
|
|
clientFinalNoProof := clientFinal[:pIdx]
|
|
clientProofB64 := clientFinal[pIdx+3:]
|
|
finalAttrs := scramParseAttrs(clientFinalNoProof)
|
|
if finalAttrs["r"] != combinedNonce {
|
|
s.sendError("28P01", "scram: nonce mismatch")
|
|
return errAuthRejected
|
|
}
|
|
|
|
// Step 4 — resolve role + compute the same SaltedPassword the
|
|
// client did, then verify the ClientProof against StoredKey.
|
|
// Constant-time verify hides whether the role exists vs. the
|
|
// password was wrong; PG itself happily distinguishes the two,
|
|
// but we err on the safe side.
|
|
r := lookupRole(s.user)
|
|
authMsg := scramAuthMessage(clientFirstBare, serverFirst, clientFinalNoProof)
|
|
var saltedPassword []byte
|
|
if r != nil {
|
|
saltedPassword = scramSaltedPassword(r.PasswordPlain, salt[:], scramIterations)
|
|
} else {
|
|
// Compute a dummy salted password so we still spend the
|
|
// PBKDF2 cycles — keeps the timing roughly constant
|
|
// regardless of role existence.
|
|
saltedPassword = scramSaltedPassword("", salt[:], scramIterations)
|
|
}
|
|
if r == nil || !scramVerifyClientProof(saltedPassword, authMsg, clientProofB64) {
|
|
s.sendError("28P01",
|
|
fmt.Sprintf("SCRAM authentication failed for user %q", s.user))
|
|
return errAuthRejected
|
|
}
|
|
|
|
// Step 5 — send server signature then AuthenticationOk.
|
|
serverSig := scramServerSignature(saltedPassword, authMsg)
|
|
s.send(&pgproto3.AuthenticationSASLFinal{Data: []byte("v=" + serverSig)})
|
|
s.send(&pgproto3.AuthenticationOk{})
|
|
return nil
|
|
}
|
|
|
|
// md5Challenge computes the response a libpq client sends for the
// MD5 password challenge, so the server can compare it with what
// the client actually sent:
//
//	"md5" + hex(md5( hex(md5(password + user)) + salt ))
//
// Both hex encodings are lowercase.
func md5Challenge(password, user string, salt []byte) string {
	h := md5.New()

	// Inner digest over password followed by user name.
	h.Write([]byte(password))
	h.Write([]byte(user))
	innerHex := hex.EncodeToString(h.Sum(nil))

	// Outer digest over the inner hex text followed by the salt.
	h.Reset()
	h.Write([]byte(innerHex))
	h.Write(salt)
	return "md5" + hex.EncodeToString(h.Sum(nil))
}
|
|
|
|
// sentinelError is a string-backed error type. Because it is a
// plain string underneath, sentinel values can be declared as
// constants and compared directly, sparing call sites the
// "fmt.Errorf" boilerplate.
type sentinelError string

// Error implements the error interface; the message is the string
// value itself.
func (e sentinelError) Error() string {
	return string(e)
}

// errAuthRejected is the sentinel the auth handlers return when
// they have refused the client.
const errAuthRejected sentinelError = "pgserver: authentication rejected"
|