// Copyright (c) 2026 Charles KWON OhJun (charleskwonohjun@gmail.com) // All rights reserved. // scram.go — SCRAM-SHA-256 (RFC 5802) primitives for the pgserver // auth path. PostgreSQL 14+ default; libpq, pgx, and JDBC all // prefer it over MD5 when offered. // // Storage model: roles in the registry hold the plaintext password // (same as the MD5 path). For each authentication we generate a // fresh salt + iteration count and derive SaltedPassword on the // fly. This is functionally equivalent to PG's "scram-stored" // secrets and keeps the code-path identical for both clients — // the wire output matches RFC 5802 byte for byte. package pgserver import ( "crypto/hmac" "crypto/pbkdf2" "crypto/sha256" "encoding/base64" "fmt" "strconv" "strings" ) // scramIterations is the PBKDF2 iteration count we advertise to // clients. PG 15+ uses 4096 (the SCRAM minimum) by default; // matching that keeps handshake latency negligible while still // satisfying RFC 5802. const scramIterations = 4096 // scramSaltedPassword runs PBKDF2-HMAC-SHA256 to derive the // 32-byte SaltedPassword from a UTF-8 password + salt. Both the // client and the server compute this independently from the same // inputs; if they disagree the resulting proofs won't match. func scramSaltedPassword(password string, salt []byte, iter int) []byte { key, err := pbkdf2.Key(sha256.New, password, salt, iter, sha256.Size) if err != nil { // PBKDF2 only errors for nonsense params (negative iter, // zero-length key). Our call site uses fixed constants so // this branch is unreachable; surface as a panic so a // future caller doesn't silently get a nil slice. panic(fmt.Sprintf("scram: pbkdf2 failure: %v", err)) } return key } // scramHMAC computes HMAC-SHA256(key, data) — RFC 5802 § 2.2. func scramHMAC(key, data []byte) []byte { m := hmac.New(sha256.New, key) m.Write(data) return m.Sum(nil) } // scramH computes SHA-256(input) — RFC 5802 § 2.2 H(). 
func scramH(input []byte) []byte {
	sum := sha256.Sum256(input)
	return sum[:]
}

// scramXOR returns a XOR b, byte by byte, with len(a) bytes of output.
// Callers pass equal-length slices (32-byte HMAC-SHA256 outputs). If b
// is shorter than a, the indexing panics.
func scramXOR(a, b []byte) []byte {
	out := make([]byte, len(a))
	for i := range a {
		out[i] = a[i] ^ b[i]
	}
	return out
}

// scramClientFirstBare strips the GS2 header from a client-first-message
// and returns the client-first-message-bare, which is the first piece of
// AuthMessage. Relevant RFC 5802 § 7 grammar:
//
//	client-first-message = gs2-header client-first-message-bare
//	gs2-header           = gs2-cbind-flag "," [ authzid ] ","
//	gs2-cbind-flag       = ("p=" cb-name) / "n" / "y"
//	authzid              = "a=" saslname
//
// Accepted flags:
//   - "n": the client does not support channel binding.
//   - "y": the client supports it but believes the server does not.
//
// "p=..." (the client requires channel binding) is rejected; this code
// has no channel-binding support. An authzid, if present, must have the
// "a=" form. It is skipped, not interpreted.
//
// ok is false when the header is malformed or the bare message is empty.
func scramClientFirstBare(clientFirst string) (bare string, ok bool) {
	flag, rest, found := strings.Cut(clientFirst, ",")
	if !found || (flag != "n" && flag != "y") {
		return "", false
	}
	authzid, tail, found := strings.Cut(rest, ",")
	if !found {
		return "", false
	}
	// saslname cannot contain a raw ',' (it is escaped as "=2C"), so the
	// second comma always ends the gs2-header.
	if authzid != "" && !strings.HasPrefix(authzid, "a=") {
		return "", false
	}
	// client-first-message-bare must at least carry "n=" and "r=".
	if tail == "" {
		return "", false
	}
	return tail, true
}

// scramParseAttrs splits a SCRAM attribute list "k=v,k=v,..." into a map.
// Attribute values never contain ',', but they can contain '='. Examples
// are base64 padding in "p=", or the "=2C"/"=3D" escapes in a username.
// So each item is split at its first '=' only.
//
// Items with no '=' or with an empty key are ignored. A repeated key
// keeps its last value.
func scramParseAttrs(s string) map[string]string {
	out := make(map[string]string)
	for _, item := range strings.Split(s, ",") {
		key, value, found := strings.Cut(item, "=")
		if !found || key == "" {
			continue
		}
		out[key] = value
	}
	return out
}

// scramServerFirst builds the server-first-message:
//
//	"r=<combined-nonce>,s=<base64(salt)>,i=<iteration-count>"
//
// Returned as a string (becomes Data on AuthenticationSASLContinue).
func scramServerFirst(combinedNonce string, salt []byte, iter int) string { return "r=" + combinedNonce + ",s=" + base64.StdEncoding.EncodeToString(salt) + ",i=" + strconv.Itoa(iter) } // scramAuthMessage builds the SCRAM AuthMessage: // // AuthMessage = client-first-bare + "," + server-first + // "," + client-final-without-proof // // Both sides compute this identically; it's the input both the // client proof and the server signature HMAC over. func scramAuthMessage(clientFirstBare, serverFirst, clientFinalNoProof string) []byte { return []byte(clientFirstBare + "," + serverFirst + "," + clientFinalNoProof) } // scramVerifyClientProof reconstructs ClientKey from the client's // proof and confirms that H(ClientKey) == StoredKey. Returns true // iff the proof is valid. // // ClientSignature = HMAC(StoredKey, AuthMessage) // ClientKey' = ClientProof XOR ClientSignature // verify H(ClientKey') == StoredKey func scramVerifyClientProof(saltedPassword []byte, authMessage []byte, clientProofB64 string) bool { clientProof, err := base64.StdEncoding.DecodeString(clientProofB64) if err != nil || len(clientProof) != sha256.Size { return false } clientKey := scramHMAC(saltedPassword, []byte("Client Key")) storedKey := scramH(clientKey) clientSig := scramHMAC(storedKey, authMessage) recovered := scramXOR(clientProof, clientSig) return hmac.Equal(scramH(recovered), storedKey) } // scramServerSignature computes the server's proof to be returned // in AuthenticationSASLFinal: // // ServerKey = HMAC(SaltedPassword, "Server Key") // ServerSignature = HMAC(ServerKey, AuthMessage) // // Returns base64 — wire format expects "v=" + base64(signature). func scramServerSignature(saltedPassword []byte, authMessage []byte) string { serverKey := scramHMAC(saltedPassword, []byte("Server Key")) sig := scramHMAC(serverKey, authMessage) return base64.StdEncoding.EncodeToString(sig) }