Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

correct handling of role cert key path when service key filename is user-specified #2213

Merged
merged 1 commit into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions libs/go/sia/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ func GetPrevRoleCertDates(certFile string) (*rdl.Timestamp, *rdl.Timestamp, erro
}

func RoleKey(rotateKey bool, svcKey string) (*rsa.PrivateKey, error) {
if rotateKey == true {
if rotateKey {
return util.GenerateKeyPair(2048)
} else {
return util.PrivateKeyFromFile(svcKey)
}
return util.PrivateKeyFromFile(svcKey)
}

func GetRoleCertificates(ztsUrl string, opts *options.Options) (int, int) {
Expand Down Expand Up @@ -448,14 +449,17 @@ func SaveRoleCertKey(key, cert []byte, role options.Role, opts *options.Options)
if role.Filename != "" {
certPrefix = strings.TrimSuffix(role.Filename, ".cert.pem")
}
svcKeyFile := ""
keyPrefix := fmt.Sprintf("%s.%s", opts.Domain, role.Service)
if opts.GenerateRoleKey == true {
if opts.GenerateRoleKey {
keyPrefix = role.Name
if role.Filename != "" {
keyPrefix = strings.TrimSuffix(role.Filename, ".cert.pem")
}
} else {
svcKeyFile = util.GetSvcKeyFileName(opts.KeyDir, role.SvcKeyFilename, opts.Domain, role.Service)
}
return util.SaveRoleCertKey(key, cert, role.Filename, keyPrefix, certPrefix, role.Uid, role.Gid, role.FileMode, opts.GenerateRoleKey, opts.RotateKey, opts.KeyDir, opts.CertDir, opts.BackupDir, opts.FileDirectUpdate)
return util.SaveRoleCertKey(key, cert, svcKeyFile, role.Filename, keyPrefix, certPrefix, role.Uid, role.Gid, role.FileMode, opts.GenerateRoleKey, opts.RotateKey, opts.KeyDir, opts.CertDir, opts.BackupDir, opts.FileDirectUpdate)
}

func restartSshdService() error {
Expand Down
12 changes: 8 additions & 4 deletions libs/go/sia/aws/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,11 @@ func GetPrevRoleCertDates(certFile string) (*rdl.Timestamp, *rdl.Timestamp, erro
}

func RoleKey(rotateKey bool, svcKey string) (*rsa.PrivateKey, error) {
if rotateKey == true {
if rotateKey {
return util.GenerateKeyPair(2048)
} else {
return util.PrivateKeyFromFile(svcKey)
}
return util.PrivateKeyFromFile(svcKey)
}

func GetRoleCertificates(ztsUrl string, opts *options.Options) (int, int) {
Expand Down Expand Up @@ -438,14 +439,17 @@ func SaveRoleCertKey(key, cert []byte, role options.Role, opts *options.Options)
if role.Filename != "" {
certPrefix = strings.TrimSuffix(role.Filename, ".cert.pem")
}
svcKeyFile := ""
keyPrefix := fmt.Sprintf("%s.%s", opts.Domain, role.Service)
if opts.GenerateRoleKey == true {
if opts.GenerateRoleKey {
keyPrefix = role.Name
if role.Filename != "" {
keyPrefix = strings.TrimSuffix(role.Filename, ".cert.pem")
}
} else {
svcKeyFile = util.GetSvcKeyFileName(opts.KeyDir, role.SvcKeyFilename, opts.Domain, role.Service)
}
return util.SaveRoleCertKey(key, cert, role.Filename, keyPrefix, certPrefix, role.Uid, role.Gid, role.FileMode, opts.GenerateRoleKey, opts.RotateKey, opts.KeyDir, opts.CertDir, opts.BackupDir, opts.FileDirectUpdate)
return util.SaveRoleCertKey(key, cert, svcKeyFile, role.Filename, keyPrefix, certPrefix, role.Uid, role.Gid, role.FileMode, opts.GenerateRoleKey, opts.RotateKey, opts.KeyDir, opts.CertDir, opts.BackupDir, opts.FileDirectUpdate)
}

func restartSshdService() error {
Expand Down
21 changes: 14 additions & 7 deletions libs/go/sia/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -730,16 +730,23 @@ func ParseEnvFloatFlag(varName string, defaultValue float64) float64 {
return value
}

func getCertKeyFileName(file, keyDir, certDir, keyPrefix, certPrefix string) (string, string) {
if file != "" && file[0] == '/' {
return file, fmt.Sprintf("%s/%s.key.pem", keyDir, keyPrefix)
func getCertKeyFileName(keyFile, certFile, keyDir, certDir, keyPrefix, certPrefix string) (string, string) {
if keyFile == "" {
keyFile = fmt.Sprintf("%s/%s.key.pem", keyDir, keyPrefix)
}
if certFile != "" {
if certFile[0] == '/' {
return keyFile, certFile
} else {
return keyFile, fmt.Sprintf("%s/%s", certDir, certFile)
}
} else {
return fmt.Sprintf("%s/%s.cert.pem", certDir, certPrefix), fmt.Sprintf("%s/%s.key.pem", keyDir, keyPrefix)
return keyFile, fmt.Sprintf("%s/%s.cert.pem", certDir, certPrefix)
}
}

func SaveRoleCertKey(key, cert []byte, file, keyPrefix, certPrefix string, uid, gid, fileMode int, createKey, rotateKey bool, keyDir, certDir, backupDir string, fileDirectUpdate bool) error {
certFile, keyFile := getCertKeyFileName(file, keyDir, certDir, keyPrefix, certPrefix)
func SaveRoleCertKey(key, cert []byte, svcKeyFile, roleCertFile, keyPrefix, certPrefix string, uid, gid, fileMode int, createKey, rotateKey bool, keyDir, certDir, backupDir string, fileDirectUpdate bool) error {
keyFile, certFile := getCertKeyFileName(svcKeyFile, roleCertFile, keyDir, certDir, keyPrefix, certPrefix)
return SaveCertKey(key, cert, keyFile, certFile, keyPrefix, certPrefix, uid, gid, fileMode, createKey, rotateKey, backupDir, fileDirectUpdate)
}

Expand Down Expand Up @@ -790,7 +797,7 @@ func SaveCertKey(key, cert []byte, keyFile, certFile, keyPrefix, certPrefix stri
return err
}
} else if FileExists(keyFile) {
log.Printf("Updating existing key file %s", keyFile)
log.Printf("Updating existing key file %s ownership only", keyFile)
UpdateKeyOwnership(keyFile, uid, gid, os.FileMode(fileMode), fileDirectUpdate)
}
log.Printf("Updating the cert file %s", certFile)
Expand Down
91 changes: 91 additions & 0 deletions libs/go/sia/util/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1263,3 +1263,94 @@ func TestParseScriptArguments(t *testing.T) {
})
}
}

func TestGetCertKeyFileName(t *testing.T) {

tests := []struct {
name string
keyFile string
certFile string
keyDir string
certDir string
keyPrefix string
certPrefix string
scriptPath string
resultKey string
resultCert string
}{
{
name: "non-empty full-path key/cert files",
keyFile: "/var/athenz/key.pem",
certFile: "/var/athenz/cert.pem",
keyDir: "not-used",
certDir: "not-used",
keyPrefix: "not-used",
certPrefix: "not-used",
resultKey: "/var/athenz/key.pem",
resultCert: "/var/athenz/cert.pem",
},
{
name: "empty key file and full-path cert file",
keyFile: "",
certFile: "/var/athenz/cert.pem",
keyDir: "/var/test1",
certDir: "not-used",
keyPrefix: "key-prefix",
certPrefix: "not-used",
resultKey: "/var/test1/key-prefix.key.pem",
resultCert: "/var/athenz/cert.pem",
},
{
name: "non-empty full-path key file and cert file",
keyFile: "/var/athenz/key.pem",
certFile: "cert-file",
keyDir: "not-used",
certDir: "/var/test2",
keyPrefix: "not-used",
certPrefix: "not-used",
resultKey: "/var/athenz/key.pem",
resultCert: "/var/test2/cert-file",
},
{
name: "empty key file and full-path cert file",
keyFile: "",
certFile: "cert-file",
keyDir: "/var/test1",
certDir: "/var/test2",
keyPrefix: "key-prefix",
certPrefix: "not-used",
resultKey: "/var/test1/key-prefix.key.pem",
resultCert: "/var/test2/cert-file",
},
{
name: "empty key and cert files",
keyFile: "",
certFile: "",
keyDir: "/var/test3",
certDir: "/var/test4",
keyPrefix: "key-prefix",
certPrefix: "cert-prefix",
resultKey: "/var/test3/key-prefix.key.pem",
resultCert: "/var/test4/cert-prefix.cert.pem",
},
{
name: "non-empty full-path key file and empty cert files",
keyFile: "/var/athenz/key.pem",
certFile: "",
keyDir: "not-used",
certDir: "/var/test4",
keyPrefix: "not-used",
certPrefix: "cert-prefix",
resultKey: "/var/athenz/key.pem",
resultCert: "/var/test4/cert-prefix.cert.pem",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
keyFile, certFile := getCertKeyFileName(tt.keyFile, tt.certFile, tt.keyDir, tt.certDir, tt.keyPrefix, tt.certPrefix)
assert.Equal(t, tt.resultKey, keyFile)
assert.Equal(t, tt.resultCert, certFile)
})
}
}