Skip to content

Commit

Permalink
Merge pull request opendatahub-io#122 from MichaelClifford/kfp-pytorc…
Browse files Browse the repository at this point in the history
…h-job

Add training args to pipeline
  • Loading branch information
Shreyanand authored Oct 23, 2024
2 parents 7f22269 + 3909274 commit abe3dc3
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
32 changes: 32 additions & 0 deletions pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@
MAX_WORKERS = "auto"
MERGE_SYSTEM_USER_MESSAGE = False

# training args
NUM_EPOCHS_PHASE_1 = 2
NUM_EPOCHS_PHASE_2 = 2
EFFECTIVE_BATCH_SIZE = 3840
LEARNING_RATE = 1e-4
NUM_WARMUP_STEPS = 800
SAVE_SAMPLES = 0
MAX_BATCH_LEN = 20000
SEED = 42


def pipeline_wrapper(mock: List[Literal[MOCKED_STAGES]]):
"""Wrapper for KFP pipeline, which allows for mocking individual stages."""
Expand Down Expand Up @@ -94,6 +104,14 @@ def pipeline(
device: str = None,
nproc_per_node: int = 3,
nnodes: int = 2,
num_epochs_phase_1: int = NUM_EPOCHS_PHASE_1,
num_epochs_phase_2: int = NUM_EPOCHS_PHASE_2,
effective_batch_size: int = EFFECTIVE_BATCH_SIZE,
learning_rate: float = LEARNING_RATE,
num_warmup_steps: int = NUM_WARMUP_STEPS,
save_samples: int = SAVE_SAMPLES,
max_batch_len: int = MAX_BATCH_LEN,
seed: int = SEED,
):
# SDG stage
git_clone_task = git_clone_op(
Expand Down Expand Up @@ -185,6 +203,13 @@ def pipeline(
phase_num=1,
nproc_per_node=nproc_per_node,
nnodes=nnodes,
num_epochs=num_epochs_phase_1,
effective_batch_size=effective_batch_size,
learning_rate=learning_rate,
num_warmup_steps=num_warmup_steps,
save_samples=save_samples,
max_batch_len=max_batch_len,
seed=seed,
)
pytorchjob_manifest_task.set_caching_options(False)

Expand Down Expand Up @@ -255,6 +280,13 @@ def pipeline(
phase_num=2,
nproc_per_node=nproc_per_node,
nnodes=nnodes,
num_epochs=num_epochs_phase_2,
effective_batch_size=effective_batch_size,
learning_rate=learning_rate,
num_warmup_steps=num_warmup_steps,
save_samples=save_samples,
max_batch_len=max_batch_len,
seed=seed,
)

pytorchjob_manifest_2_task.set_caching_options(False)
Expand Down
68 changes: 68 additions & 0 deletions pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,24 @@
# base_model: str [Default: 'ibm-granite/granite-7b-base']
# batch_size: int [Default: 8.0]
# device: str
# effective_batch_size: int [Default: 3840.0]
# few_shots: int [Default: 5.0]
# learning_rate: float [Default: 0.0001]
# max_batch_len: int [Default: 20000.0]
# max_workers: str [Default: 'auto']
# merge_system_user_message: bool [Default: False]
# model_dtype: str [Default: 'bfloat16']
# nnodes: int [Default: 2.0]
# nproc_per_node: int [Default: 3.0]
# num_epochs_phase_1: int [Default: 2.0]
# num_epochs_phase_2: int [Default: 2.0]
# num_instructions_to_generate: int [Default: 2.0]
# num_warmup_steps: int [Default: 800.0]
# repo_branch: str
# repo_pr: int
# repo_url: str [Default: 'https://github.com/instructlab/taxonomy.git']
# save_samples: int [Default: 0.0]
# seed: int [Default: 42.0]
# storage_class_name: str [Default: 'nfs-csi']
components:
comp-artifact-to-pvc-op:
Expand Down Expand Up @@ -1931,10 +1939,16 @@ root:
- createpvc-3
inputs:
parameters:
effective_batch_size:
componentInputParameter: effective_batch_size
input_pvc_name:
taskOutputParameter:
outputParameterKey: name
producerTask: createpvc-2
learning_rate:
componentInputParameter: learning_rate
max_batch_len:
componentInputParameter: max_batch_len
model_pvc_name:
taskOutputParameter:
outputParameterKey: name
Expand All @@ -1947,13 +1961,21 @@ root:
componentInputParameter: nnodes
nproc_per_node:
componentInputParameter: nproc_per_node
num_epochs:
componentInputParameter: num_epochs_phase_1
num_warmup_steps:
componentInputParameter: num_warmup_steps
output_pvc_name:
taskOutputParameter:
outputParameterKey: name
producerTask: createpvc-3
phase_num:
runtimeValue:
constant: 1.0
save_samples:
componentInputParameter: save_samples
seed:
componentInputParameter: seed
taskInfo:
name: pytorchjob-manifest-op
pytorchjob-manifest-op-2:
Expand All @@ -1967,10 +1989,16 @@ root:
- kubectl-wait-for-op
inputs:
parameters:
effective_batch_size:
componentInputParameter: effective_batch_size
input_pvc_name:
taskOutputParameter:
outputParameterKey: name
producerTask: createpvc-2
learning_rate:
componentInputParameter: learning_rate
max_batch_len:
componentInputParameter: max_batch_len
model_pvc_name:
taskOutputParameter:
outputParameterKey: name
Expand All @@ -1983,13 +2011,21 @@ root:
componentInputParameter: nnodes
nproc_per_node:
componentInputParameter: nproc_per_node
num_epochs:
componentInputParameter: num_epochs_phase_2
num_warmup_steps:
componentInputParameter: num_warmup_steps
output_pvc_name:
taskOutputParameter:
outputParameterKey: name
producerTask: createpvc-3
phase_num:
runtimeValue:
constant: 2.0
save_samples:
componentInputParameter: save_samples
seed:
componentInputParameter: seed
taskInfo:
name: pytorchjob-manifest-op-2
run-final-eval-op:
Expand Down Expand Up @@ -2097,10 +2133,22 @@ root:
device:
isOptional: true
parameterType: STRING
effective_batch_size:
defaultValue: 3840.0
isOptional: true
parameterType: NUMBER_INTEGER
few_shots:
defaultValue: 5.0
isOptional: true
parameterType: NUMBER_INTEGER
learning_rate:
defaultValue: 0.0001
isOptional: true
parameterType: NUMBER_DOUBLE
max_batch_len:
defaultValue: 20000.0
isOptional: true
parameterType: NUMBER_INTEGER
max_workers:
defaultValue: auto
isOptional: true
Expand All @@ -2121,10 +2169,22 @@ root:
defaultValue: 3.0
isOptional: true
parameterType: NUMBER_INTEGER
num_epochs_phase_1:
defaultValue: 2.0
isOptional: true
parameterType: NUMBER_INTEGER
num_epochs_phase_2:
defaultValue: 2.0
isOptional: true
parameterType: NUMBER_INTEGER
num_instructions_to_generate:
defaultValue: 2.0
isOptional: true
parameterType: NUMBER_INTEGER
num_warmup_steps:
defaultValue: 800.0
isOptional: true
parameterType: NUMBER_INTEGER
repo_branch:
isOptional: true
parameterType: STRING
Expand All @@ -2135,6 +2195,14 @@ root:
defaultValue: https://github.com/instructlab/taxonomy.git
isOptional: true
parameterType: STRING
save_samples:
defaultValue: 0.0
isOptional: true
parameterType: NUMBER_INTEGER
seed:
defaultValue: 42.0
isOptional: true
parameterType: NUMBER_INTEGER
storage_class_name:
defaultValue: nfs-csi
isOptional: true
Expand Down

0 comments on commit abe3dc3

Please sign in to comment.