Skip to content

Commit

Permalink
curve25519: use crypto/ecdh on Go 1.20
Browse files Browse the repository at this point in the history
For golang/go#52221

Change-Id: I27e867d4cc89cd52c8d510f0dbab4e89b7cd4763
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/451115
Auto-Submit: Filippo Valsorda <[email protected]>
Reviewed-by: Cherry Mui <[email protected]>
TryBot-Result: Gopher Robot <[email protected]>
Run-TryBot: Filippo Valsorda <[email protected]>
Reviewed-by: Roland Shoemaker <[email protected]>
  • Loading branch information
FiloSottile authored and gopherbot committed Mar 13, 2023
1 parent c6a20f9 commit 9cd0187
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 107 deletions.
99 changes: 6 additions & 93 deletions curve25519/curve25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,71 +5,18 @@
// Package curve25519 provides an implementation of the X25519 function, which
// performs scalar multiplication on the elliptic curve known as Curve25519.
// See RFC 7748.
//
// Starting in Go 1.20, this package is a wrapper for the X25519 implementation
// in the crypto/ecdh package.
package curve25519 // import "golang.org/x/crypto/curve25519"

import (
"crypto/subtle"
"errors"
"strconv"

"golang.org/x/crypto/curve25519/internal/field"
)

// ScalarMult sets dst to the product scalar * point.
//
// Deprecated: when provided a low-order point, ScalarMult will set dst to all
// zeroes, irrespective of the scalar. Instead, use the X25519 function, which
// will return an error.
func ScalarMult(dst, scalar, point *[32]byte) {
var e [32]byte

copy(e[:], scalar[:])
e[0] &= 248
e[31] &= 127
e[31] |= 64

var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
x1.SetBytes(point[:])
x2.One()
x3.Set(&x1)
z3.One()

swap := 0
for pos := 254; pos >= 0; pos-- {
b := e[pos/8] >> uint(pos&7)
b &= 1
swap ^= int(b)
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
swap = int(b)

tmp0.Subtract(&x3, &z3)
tmp1.Subtract(&x2, &z2)
x2.Add(&x2, &z2)
z2.Add(&x3, &z3)
z3.Multiply(&tmp0, &x2)
z2.Multiply(&z2, &tmp1)
tmp0.Square(&tmp1)
tmp1.Square(&x2)
x3.Add(&z3, &z2)
z2.Subtract(&z3, &z2)
x2.Multiply(&tmp1, &tmp0)
tmp1.Subtract(&tmp1, &tmp0)
z2.Square(&z2)

z3.Mult32(&tmp1, 121666)
x3.Square(&x3)
tmp0.Add(&tmp0, &z3)
z3.Multiply(&x1, &z2)
z2.Multiply(&tmp1, &tmp0)
}

x2.Swap(&x3, swap)
z2.Swap(&z3, swap)

z2.Invert(&z2)
x2.Multiply(&x2, &z2)
copy(dst[:], x2.Bytes())
scalarMult(dst, scalar, point)
}

// ScalarBaseMult sets dst to the product scalar * base where base is the
Expand All @@ -78,7 +25,7 @@ func ScalarMult(dst, scalar, point *[32]byte) {
// It is recommended to use the X25519 function with Basepoint instead, as
// copying into fixed size arrays can lead to unexpected bugs.
func ScalarBaseMult(dst, scalar *[32]byte) {
ScalarMult(dst, scalar, &basePoint)
scalarBaseMult(dst, scalar)
}

const (
Expand All @@ -91,21 +38,10 @@ const (
// Basepoint is the canonical Curve25519 generator.
var Basepoint []byte

var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
var basePoint = [32]byte{9}

func init() { Basepoint = basePoint[:] }

func checkBasepoint() {
if subtle.ConstantTimeCompare(Basepoint, []byte{
0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}) != 1 {
panic("curve25519: global Basepoint value was modified")
}
}

// X25519 returns the result of the scalar multiplication (scalar * point),
// according to RFC 7748, Section 5. scalar, point and the return value are
// slices of 32 bytes.
Expand All @@ -121,26 +57,3 @@ func X25519(scalar, point []byte) ([]byte, error) {
var dst [32]byte
return x25519(&dst, scalar, point)
}

func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
var in [32]byte
if l := len(scalar); l != 32 {
return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32")
}
if l := len(point); l != 32 {
return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32")
}
copy(in[:], scalar)
if &point[0] == &Basepoint[0] {
checkBasepoint()
ScalarBaseMult(dst, &in)
} else {
var base, zero [32]byte
copy(base[:], point)
ScalarMult(dst, &in, &base)
if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
return nil, errors.New("bad input point: low order point")
}
}
return dst[:], nil
}
105 changes: 105 additions & 0 deletions curve25519/curve25519_compat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !go1.20

package curve25519

import (
"crypto/subtle"
"errors"
"strconv"

"golang.org/x/crypto/curve25519/internal/field"
)

func scalarMult(dst, scalar, point *[32]byte) {
var e [32]byte

copy(e[:], scalar[:])
e[0] &= 248
e[31] &= 127
e[31] |= 64

var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
x1.SetBytes(point[:])
x2.One()
x3.Set(&x1)
z3.One()

swap := 0
for pos := 254; pos >= 0; pos-- {
b := e[pos/8] >> uint(pos&7)
b &= 1
swap ^= int(b)
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
swap = int(b)

tmp0.Subtract(&x3, &z3)
tmp1.Subtract(&x2, &z2)
x2.Add(&x2, &z2)
z2.Add(&x3, &z3)
z3.Multiply(&tmp0, &x2)
z2.Multiply(&z2, &tmp1)
tmp0.Square(&tmp1)
tmp1.Square(&x2)
x3.Add(&z3, &z2)
z2.Subtract(&z3, &z2)
x2.Multiply(&tmp1, &tmp0)
tmp1.Subtract(&tmp1, &tmp0)
z2.Square(&z2)

z3.Mult32(&tmp1, 121666)
x3.Square(&x3)
tmp0.Add(&tmp0, &z3)
z3.Multiply(&x1, &z2)
z2.Multiply(&tmp1, &tmp0)
}

x2.Swap(&x3, swap)
z2.Swap(&z3, swap)

z2.Invert(&z2)
x2.Multiply(&x2, &z2)
copy(dst[:], x2.Bytes())
}

func scalarBaseMult(dst, scalar *[32]byte) {
checkBasepoint()
scalarMult(dst, scalar, &basePoint)
}

func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
var in [32]byte
if l := len(scalar); l != 32 {
return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32")
}
if l := len(point); l != 32 {
return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32")
}
copy(in[:], scalar)
if &point[0] == &Basepoint[0] {
scalarBaseMult(dst, &in)
} else {
var base, zero [32]byte
copy(base[:], point)
scalarMult(dst, &in, &base)
if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
return nil, errors.New("bad input point: low order point")
}
}
return dst[:], nil
}

func checkBasepoint() {
if subtle.ConstantTimeCompare(Basepoint, []byte{
0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}) != 1 {
panic("curve25519: global Basepoint value was modified")
}
}
46 changes: 46 additions & 0 deletions curve25519/curve25519_go120.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build go1.20

package curve25519

import "crypto/ecdh"

func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
curve := ecdh.X25519()
pub, err := curve.NewPublicKey(point)
if err != nil {
return nil, err
}
priv, err := curve.NewPrivateKey(scalar)
if err != nil {
return nil, err
}
out, err := priv.ECDH(pub)
if err != nil {
return nil, err
}
copy(dst[:], out)
return dst[:], nil
}

func scalarMult(dst, scalar, point *[32]byte) {
if _, err := x25519(dst, scalar[:], point[:]); err != nil {
// The only error condition for x25519 when the inputs are 32 bytes long
// is if the output would have been the all-zero value.
for i := range dst {
dst[i] = 0
}
}
}

func scalarBaseMult(dst, scalar *[32]byte) {
curve := ecdh.X25519()
priv, err := curve.NewPrivateKey(scalar[:])
if err != nil {
panic("curve25519: internal error: scalarBaseMult was not 32 bytes")
}
copy(dst[:], priv.PublicKey().Bytes())
}
28 changes: 15 additions & 13 deletions curve25519/curve25519_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package curve25519
package curve25519_test

import (
"bytes"
"crypto/rand"
"encoding/hex"
"testing"

"golang.org/x/crypto/curve25519"
)

const expectedHex = "89161fde887b2b53de549af483940106ecc114d6982daa98256de23bdf77661a"
Expand All @@ -19,7 +21,7 @@ func TestX25519Basepoint(t *testing.T) {

for i := 0; i < 200; i++ {
var err error
x, err = X25519(x, Basepoint)
x, err = curve25519.X25519(x, curve25519.Basepoint)
if err != nil {
t.Fatal(err)
}
Expand All @@ -32,12 +34,12 @@ func TestX25519Basepoint(t *testing.T) {
}

func TestLowOrderPoints(t *testing.T) {
scalar := make([]byte, ScalarSize)
scalar := make([]byte, curve25519.ScalarSize)
if _, err := rand.Read(scalar); err != nil {
t.Fatal(err)
}
for i, p := range lowOrderPoints {
out, err := X25519(scalar, p)
out, err := curve25519.X25519(scalar, p)
if err == nil {
t.Errorf("%d: expected error, got nil", i)
}
Expand All @@ -48,10 +50,10 @@ func TestLowOrderPoints(t *testing.T) {
}

func TestTestVectors(t *testing.T) {
t.Run("Legacy", func(t *testing.T) { testTestVectors(t, ScalarMult) })
t.Run("Legacy", func(t *testing.T) { testTestVectors(t, curve25519.ScalarMult) })
t.Run("X25519", func(t *testing.T) {
testTestVectors(t, func(dst, scalar, point *[32]byte) {
out, err := X25519(scalar[:], point[:])
out, err := curve25519.X25519(scalar[:], point[:])
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -88,10 +90,10 @@ func TestHighBitIgnored(t *testing.T) {
var hi0, hi1 [32]byte

u[31] &= 0x7f
ScalarMult(&hi0, &s, &u)
curve25519.ScalarMult(&hi0, &s, &u)

u[31] |= 0x80
ScalarMult(&hi1, &s, &u)
curve25519.ScalarMult(&hi1, &s, &u)

if !bytes.Equal(hi0[:], hi1[:]) {
t.Errorf("high bit of group point should not affect result")
Expand All @@ -101,14 +103,14 @@ func TestHighBitIgnored(t *testing.T) {
var benchmarkSink byte

func BenchmarkX25519Basepoint(b *testing.B) {
scalar := make([]byte, ScalarSize)
scalar := make([]byte, curve25519.ScalarSize)
if _, err := rand.Read(scalar); err != nil {
b.Fatal(err)
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
out, err := X25519(scalar, Basepoint)
out, err := curve25519.X25519(scalar, curve25519.Basepoint)
if err != nil {
b.Fatal(err)
}
Expand All @@ -117,11 +119,11 @@ func BenchmarkX25519Basepoint(b *testing.B) {
}

func BenchmarkX25519(b *testing.B) {
scalar := make([]byte, ScalarSize)
scalar := make([]byte, curve25519.ScalarSize)
if _, err := rand.Read(scalar); err != nil {
b.Fatal(err)
}
point, err := X25519(scalar, Basepoint)
point, err := curve25519.X25519(scalar, curve25519.Basepoint)
if err != nil {
b.Fatal(err)
}
Expand All @@ -131,7 +133,7 @@ func BenchmarkX25519(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
out, err := X25519(scalar, point)
out, err := curve25519.X25519(scalar, point)
if err != nil {
b.Fatal(err)
}
Expand Down
Loading

0 comments on commit 9cd0187

Please sign in to comment.