diff --git a/.changelog/31755.txt b/.changelog/31755.txt new file mode 100644 index 00000000000..4bcd501601b --- /dev/null +++ b/.changelog/31755.txt @@ -0,0 +1,3 @@ +```release-note:enhancement +resource/aws_sagemaker_model: Add `container.model_package_name` and `primary_container.model_package_name` arguments +``` \ No newline at end of file diff --git a/internal/service/sagemaker/model.go b/internal/service/sagemaker/model.go index ab8c9ff48ca..00307ce3cf9 100644 --- a/internal/service/sagemaker/model.go +++ b/internal/service/sagemaker/model.go @@ -59,7 +59,7 @@ func ResourceModel() *schema.Resource { }, "image": { Type: schema.TypeString, - Required: true, + Optional: true, ForceNew: true, ValidateFunc: validImage, }, @@ -106,6 +106,12 @@ func ResourceModel() *schema.Resource { ForceNew: true, ValidateFunc: validModelDataURL, }, + "model_package_name": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + ValidateFunc: verify.ValidARN, + }, }, }, }, @@ -164,7 +170,7 @@ func ResourceModel() *schema.Resource { }, "image": { Type: schema.TypeString, - Required: true, + Optional: true, ForceNew: true, ValidateFunc: validImage, }, @@ -211,6 +217,12 @@ func ResourceModel() *schema.Resource { ForceNew: true, ValidateFunc: validModelDataURL, }, + "model_package_name": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + ValidateFunc: verify.ValidARN, + }, }, }, }, @@ -404,8 +416,10 @@ func resourceModelDelete(ctx context.Context, d *schema.ResourceData, meta inter } func expandContainer(m map[string]interface{}) *sagemaker.ContainerDefinition { - container := sagemaker.ContainerDefinition{ - Image: aws.String(m["image"].(string)), + container := sagemaker.ContainerDefinition{} + + if v, ok := m["image"]; ok && v.(string) != "" { + container.Image = aws.String(v.(string)) } if v, ok := m["mode"]; ok && v.(string) != "" { @@ -418,6 +432,9 @@ func expandContainer(m map[string]interface{}) *sagemaker.ContainerDefinition { if v, ok := m["model_data_url"]; ok && v.(string) != "" { container.ModelDataUrl = aws.String(v.(string)) } + if v, ok := m["model_package_name"]; ok && v.(string) != "" { + container.ModelPackageName = aws.String(v.(string)) + } if v, ok := m["environment"].(map[string]interface{}); ok && len(v) > 0 { container.Environment = flex.ExpandStringMap(v) } @@ -478,7 +495,9 @@ func flattenContainer(container *sagemaker.ContainerDefinition) []interface{} { cfg := make(map[string]interface{}) - cfg["image"] = aws.StringValue(container.Image) + if container.Image != nil { + cfg["image"] = aws.StringValue(container.Image) + } if container.Mode != nil { cfg["mode"] = aws.StringValue(container.Mode) @@ -490,6 +509,9 @@ func flattenContainer(container *sagemaker.ContainerDefinition) []interface{} { if container.ModelDataUrl != nil { cfg["model_data_url"] = aws.StringValue(container.ModelDataUrl) } + if container.ModelPackageName != nil { + cfg["model_package_name"] = aws.StringValue(container.ModelPackageName) + } if container.Environment != nil { cfg["environment"] = aws.StringValueMap(container.Environment) } diff --git a/internal/service/sagemaker/model_test.go b/internal/service/sagemaker/model_test.go index afcc3ccf031..6c57b9cca90 100644 --- a/internal/service/sagemaker/model_test.go +++ b/internal/service/sagemaker/model_test.go @@ -262,6 +262,33 @@ func TestAccSageMakerModel_primaryContainerModeSingle(t *testing.T) { }) } +func TestAccSageMakerModel_primaryContainerModelPackageName(t *testing.T) { + ctx := acctest.Context(t) + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_sagemaker_model.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(ctx, t) }, + ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID), + ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories, + CheckDestroy: testAccCheckModelDestroy(ctx), + Steps: []resource.TestStep{ + { + Config: testAccModelConfig_primaryContainerPackageName(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckModelExists(ctx, resourceName), + resource.TestCheckResourceAttrSet(resourceName, "primary_container.0.model_package_name"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + func TestAccSageMakerModel_containers(t *testing.T) { ctx := acctest.Context(t) rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) @@ -449,7 +476,7 @@ func testAccCheckModelExists(ctx context.Context, n string) resource.TestCheckFu } } -func testAccModelConfigBase(rName string) string { +func testAccModelConfig_base(rName string) string { return fmt.Sprintf(` resource "aws_iam_role" "test" { name = %[1]q @@ -475,7 +502,7 @@ data "aws_sagemaker_prebuilt_ecr_image" "test" { } func testAccModelConfig_basic(rName string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -488,7 +515,7 @@ resource "aws_sagemaker_model" "test" { } func testAccModelConfig_inferenceExecution(rName string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -509,7 +536,7 @@ resource "aws_sagemaker_model" "test" { } func testAccModelConfig_tags1(rName, tagKey1, tagValue1 string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -526,7 +553,7 @@ resource "aws_sagemaker_model" "test" { } func testAccModelConfig_tags2(rName, tagKey1, tagValue1, tagKey2, tagValue2 string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -544,7 +571,7 @@ resource "aws_sagemaker_model" "test" { } func testAccModelConfig_primaryContainerDataURL(rName string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -605,11 +632,6 @@ resource "aws_s3_bucket" "test" { force_destroy = true } -resource "aws_s3_bucket_acl" "test" { - bucket = aws_s3_bucket.test.id - acl = "private" -} - resource "aws_s3_object" "test" { bucket = aws_s3_bucket.test.bucket key = "model.tar.gz" @@ -618,8 +640,54 @@ resource "aws_s3_object" "test" { `, rName)) } +// lintignore:AWSAT003,AWSAT005 +func testAccModelConfig_primaryContainerPackageName(rName string) string { + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` +data "aws_region" "current" {} + +locals { + region_account_map = { + us-east-1 = "865070037744" + us-east-2 = "057799348421" + us-west-1 = "382657785993" + us-west-2 = "594846645681" + ca-central-1 = "470592106596" + eu-central-1 = "446921602837" + eu-west-1 = "985815980388" + eu-west-2 = "856760150666" + eu-west-3 = "843114510376" + eu-north-1 = "136758871317" + ap-southeast-1 = "192199979996" + ap-southeast-2 = "666831318237" + ap-northeast-2 = "745090734665" + ap-northeast-1 = "977537786026" + ap-south-1 = "077584701553" + sa-east-1 = "270155090741" + } + + account = local.region_account_map[data.aws_region.current.name] + + model_package_name = format( + "arn:aws:sagemaker:%%s:%%s:model-package/hf-textgeneration-gpt2-cpu-b73b575105d336b680d151277ebe4ee0", + data.aws_region.current.name, + local.account + ) +} + +resource "aws_sagemaker_model" "test" { + name = %[1]q + enable_network_isolation = true + execution_role_arn = aws_iam_role.test.arn + + primary_container { + model_package_name = local.model_package_name + } +} +`, rName)) +} + func testAccModelConfig_primaryContainerHostname(rName string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -633,7 +701,7 @@ resource "aws_sagemaker_model" "test" { } func testAccModelConfig_primaryContainerImage(rName string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -650,7 +718,7 @@ resource "aws_sagemaker_model" "test" { } func testAccModelConfig_primaryContainerEnvironment(rName string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -667,7 +735,7 @@ resource "aws_sagemaker_model" "test" { } func testAccModelConfig_primaryContainerModeSingle(rName string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -681,7 +749,7 @@ resource "aws_sagemaker_model" "test" { } func testAccModelConfig_containers(rName string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -698,7 +766,7 @@ resource "aws_sagemaker_model" "test" { } func testAccModelConfig_networkIsolation(rName string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -712,7 +780,7 @@ resource "aws_sagemaker_model" "test" { } func testAccModelConfig_vpcBasic(rName string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), acctest.ConfigAvailableAZsNoOptIn(), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), acctest.ConfigVPCWithSubnets(rName, 2), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -723,50 +791,15 @@ resource "aws_sagemaker_model" "test" { } vpc_config { - subnets = [aws_subnet.test.id, aws_subnet.bar.id] - security_group_ids = [aws_security_group.test.id, aws_security_group.bar.id] - } -} - -resource "aws_vpc" "test" { - cidr_block = "10.1.0.0/16" - - tags = { - Name = %[1]q - } -} - -resource "aws_subnet" "test" { - cidr_block = "10.1.1.0/24" - availability_zone = data.aws_availability_zones.available.names[0] - vpc_id = aws_vpc.test.id - - tags = { - Name = %[1]q - } -} - -resource "aws_subnet" "bar" { - cidr_block = "10.1.2.0/24" - availability_zone = data.aws_availability_zones.available.names[0] - vpc_id = aws_vpc.test.id - - tags = { - Name = %[1]q + subnets = aws_subnet.test[*].id + security_group_ids = aws_security_group.test[*].id } } resource "aws_security_group" "test" { - name = "%[1]s-1" - vpc_id = aws_vpc.test.id - - tags = { - Name = %[1]q - } -} + count = 2 -resource "aws_security_group" "bar" { - name = "%[1]s-2" + name = "%[1]s-${count.index}" vpc_id = aws_vpc.test.id tags = { @@ -778,7 +811,7 @@ resource "aws_security_group" "bar" { // lintignore:AWSAT003,AWSAT005 func testAccModelConfig_primaryContainerPrivateDockerRegistry(rName string) string { - return acctest.ConfigCompose(testAccModelConfigBase(rName), acctest.ConfigAvailableAZsNoOptIn(), fmt.Sprintf(` + return acctest.ConfigCompose(testAccModelConfig_base(rName), acctest.ConfigVPCWithSubnets(rName, 1), fmt.Sprintf(` resource "aws_sagemaker_model" "test" { name = %[1]q execution_role_arn = aws_iam_role.test.arn @@ -797,31 +830,13 @@ resource "aws_sagemaker_model" "test" { } vpc_config { - subnets = [aws_subnet.test.id] + subnets = aws_subnet.test[*].id security_group_ids = [aws_security_group.test.id] } } -resource "aws_vpc" "test" { - cidr_block = "10.1.0.0/16" - - tags = { - Name = %[1]q - } -} - -resource "aws_subnet" "test" { - cidr_block = "10.1.1.0/24" - availability_zone = data.aws_availability_zones.available.names[0] - vpc_id = aws_vpc.test.id - - tags = { - Name = %[1]q - } -} - resource "aws_security_group" "test" { - name = "%[1]s-1" + name = %[1]q vpc_id = aws_vpc.test.id tags = { diff --git a/website/docs/r/sagemaker_model.html.markdown b/website/docs/r/sagemaker_model.html.markdown index 1d581d96f49..1889ac38b0f 100644 --- a/website/docs/r/sagemaker_model.html.markdown +++ b/website/docs/r/sagemaker_model.html.markdown @@ -59,9 +59,10 @@ The following arguments are supported: The `primary_container` and `container` block both support: -* `image` - (Required) The registry path where the inference code image is stored in Amazon ECR. +* `image` - (Optional) The registry path where the inference code image is stored in Amazon ECR. * `mode` - (Optional) The container hosts value `SingleModel/MultiModel`. The default value is `SingleModel`. * `model_data_url` - (Optional) The URL for the S3 location where model artifacts are stored. +* `model_package_name` - (Optional) The Amazon Resource Name (ARN) of the model package to use to create the model. * `container_hostname` - (Optional) The DNS host name for the container. * `environment` - (Optional) Environment variables for the Docker container. A list of key value pairs.