Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance secure credential registeration #1761

Merged
merged 2 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 56 additions & 9 deletions init/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import requests
import yaml
from tqdm import tqdm
from colorama import Fore, init
from tabulate import tabulate
from colorama import Fore, Style, init
from getpass import getpass
from Crypto.PublicKey import RSA
from Crypto.Hash import SHA256
Expand Down Expand Up @@ -39,7 +40,7 @@
ENC_FILE_PATH = os.path.join(CRED_PATH, CRED_FILE_NAME_ENC)
KEY_FILE = os.path.join(CRED_PATH, ".tmp_enc_key")

expected_completion_time_seconds = 400
expected_completion_time_seconds = 600

# Check for credential path
if not os.path.exists(CRED_PATH):
Expand Down Expand Up @@ -146,7 +147,6 @@ def encrypt_credential_value_with_publickey(public_key_pem, credentials):

return encrypted_credentials, encrypted_aes_key

# Function to register credentials using encrypted values and encrypted AES key
def register_credential(provider, credentials):
try:
if all(credentials.values()):
Expand All @@ -168,24 +168,71 @@ def register_credential(provider, credentials):
"credentialKeyValueList": [{"key": k, "value": v} for k, v in encrypted_credentials.items()],
"providerName": provider,
"publicKeyTokenId": public_key_token_id,
"encryptedAesKey": encrypted_aes_key
"encryptedClientAesKeyByPublicKey": encrypted_aes_key
}

# Step 4: Register the encrypted credentials
response = requests.post(f"http://{TUMBLEBUG_SERVER}/tumblebug/credential", json=credential_payload, headers=HEADERS)
return provider, response.json(), Fore.GREEN

if response.status_code == 200:
# Extract relevant data for message
result_data = response.json()
message = print_credential_info(result_data)
return provider, message, Fore.GREEN
else:
message = response.json().get('message', response.text)
return provider, message, Fore.RED
else:
return provider, "Incomplete credential data, Skip", Fore.RED
message = "Incomplete credential data, Skip"
return provider, message, Fore.RED
except Exception as e:
return provider, f"Error registering credentials: {str(e)}", Fore.RED
message = "Error registering credentials: " + str(e)
return provider, message, Fore.RED


# Function to print formatted credential information
def print_credential_info(response):
if 'credentialName' in response and 'credentialHolder' in response:
# Print credential name and holder in bold
print(Fore.YELLOW + f"\n{response['credentialName'].upper()} (holder: {response['credentialHolder']})" + Style.RESET_ALL)

if 'allConnections' in response and 'connectionconfig' in response['allConnections']:
# Print the explanation line in yellow
print(Style.BRIGHT + "Registered Connections" + Fore.GREEN + " [verified]" + Fore.MAGENTA + "[region representative]" + Style.RESET_ALL)

# Prepare table headers and rows
headers = ["Config Name", "Assigned Region", "Assigned Zone"]
table_rows = []
for conn in response['allConnections']['connectionconfig']:
if conn['providerName'] == response['providerName']:
# Config name with green color if verified
config_name_display = Fore.GREEN + conn['configName'] + Style.RESET_ALL if conn['verified'] else conn['configName']

# Assigned Zone with pink color if region representative
assigned_zone_display = Fore.MAGENTA + conn['regionZoneInfo']['assignedZone'] + Style.RESET_ALL if conn['regionRepresentative'] else conn['regionZoneInfo']['assignedZone']

# Add row to the table
table_rows.append([
config_name_display,
conn['regionZoneInfo']['assignedRegion'],
assigned_zone_display
])

# Print table
print(tabulate(table_rows, headers, tablefmt="grid"))


# Register credentials to TumblebugServer using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=5) as executor:
future_to_provider = {executor.submit(register_credential, provider, credentials): provider for provider, credentials in cred_data.items()}
for future in as_completed(future_to_provider):
provider, response, color = future.result()
print(color + f"- {provider}: {response}")
provider, message, color = future.result()
if message is None:
message = "" # Handle NoneType message
else:
print("")
print(color + f"- {provider.upper()}: {message}")
print_credential_info(message)

print(Fore.YELLOW + "\nLoading common Specs and Images...")
print(Fore.RESET)
Expand Down
3 changes: 3 additions & 0 deletions init/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ colorama==0.4.6

# Used for cryptographic functions like RSA encryption
pycryptodome==3.19.1

# Used for displaying data in table format
tabulate==0.8.9
3 changes: 3 additions & 0 deletions src/api/rest/docs/docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -9060,6 +9060,9 @@ const docTemplate = `{
"common.CredentialInfo": {
"type": "object",
"properties": {
"allConnections": {
"$ref": "#/definitions/common.ConnConfigList"
},
"credentialHolder": {
"type": "string"
},
Expand Down
3 changes: 3 additions & 0 deletions src/api/rest/docs/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -9053,6 +9053,9 @@
"common.CredentialInfo": {
"type": "object",
"properties": {
"allConnections": {
"$ref": "#/definitions/common.ConnConfigList"
},
"credentialHolder": {
"type": "string"
},
Expand Down
2 changes: 2 additions & 0 deletions src/api/rest/docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ definitions:
type: object
common.CredentialInfo:
properties:
allConnections:
$ref: '#/definitions/common.ConnConfigList'
credentialHolder:
type: string
credentialName:
Expand Down
18 changes: 12 additions & 6 deletions src/core/common/utility.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,11 @@ type CredentialReq struct {

// CredentialInfo is struct for containing a struct for credential info
type CredentialInfo struct {
CredentialName string `json:"credentialName"`
CredentialHolder string `json:"credentialHolder"`
ProviderName string `json:"providerName"`
KeyValueInfoList []KeyValue `json:"keyValueInfoList"`
CredentialName string `json:"credentialName"`
CredentialHolder string `json:"credentialHolder"`
ProviderName string `json:"providerName"`
KeyValueInfoList []KeyValue `json:"keyValueInfoList"`
AllConnections ConnConfigList `json:"allConnections"`
}

// GetConnConfig is func to get connection config
Expand Down Expand Up @@ -673,8 +674,7 @@ func RegisterCredential(req CredentialReq) (CredentialInfo, error) {
return CredentialInfo{}, fmt.Errorf("private key not found for token ID: %s", req.PublicKeyTokenId)
}

fmt.Printf("Private key exists: %+v\n", privateKey)
PrintJsonPretty(req)
// PrintJsonPretty(req)

// Decrypt the AES key
encryptedAesKey, err := base64.StdEncoding.DecodeString(req.EncryptedClientAesKeyByPublicKey)
Expand Down Expand Up @@ -963,6 +963,12 @@ func RegisterCredential(req CredentialReq) (CredentialInfo, error) {
}
}

callResult.AllConnections, err = GetConnConfigList(req.CredentialHolder, false, false)
if err != nil {
log.Error().Err(err).Msg("")
return callResult, err
}

return callResult, nil
}

Expand Down
Loading