Commit

use llama 3.1 for test
Signed-off-by: Sertac Ozercan <[email protected]>
sozercan committed Sep 28, 2024
1 parent ec787eb commit 8e9dc49
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions pkg/finetune/target_unsloth.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 
+from unsloth import is_bfloat16_supported
+from transformers import TrainingArguments, DataCollatorForSeq2Seq
 from unsloth import FastLanguageModel
 import torch
 from trl import SFTTrainer
@@ -71,13 +73,14 @@ def formatting_prompts_func(examples):
 else:
     dataset = load_dataset(source, split = "train")
 
-dataset = dataset.map(formatting_prompts_func, batched = True)
+dataset = dataset.map(formatting_prompts_func, batched=True)
 
 trainer = SFTTrainer(
     model=model,
     train_dataset=dataset,
     dataset_text_field="text",
     max_seq_length=max_seq_length,
+    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
     tokenizer=tokenizer,
     dataset_num_proc = 2,
     packing = cfg.get('packing'), # Can make training 5x faster for short sequences.
@@ -87,8 +90,8 @@ def formatting_prompts_func(examples):
         warmup_steps=cfg.get('warmupSteps'),
         max_steps=cfg.get('maxSteps'),
         learning_rate = cfg.get('learningRate'),
-        fp16=not torch.cuda.is_bf16_supported(),
-        bf16=torch.cuda.is_bf16_supported(),
+        fp16=not is_bfloat16_supported(),
+        bf16=is_bfloat16_supported(),
         logging_steps=cfg.get('loggingSteps'),
         optim=cfg.get('optimizer'),
         weight_decay = cfg.get('weightDecay'),
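Note: together, the Python hunks pad training batches with a DataCollatorForSeq2Seq and take the mixed-precision flags from unsloth's is_bfloat16_supported() helper instead of torch.cuda.is_bf16_supported(). A minimal sketch of that pattern follows; the tokenizer choice and output path are illustrative assumptions, not values from this repo:

    from transformers import AutoTokenizer, DataCollatorForSeq2Seq, TrainingArguments
    from unsloth import is_bfloat16_supported

    # Stand-in tokenizer for this sketch; the real script uses the tokenizer
    # returned by FastLanguageModel.from_pretrained alongside the model.
    tokenizer = AutoTokenizer.from_pretrained("unsloth/Meta-Llama-3.1-8B")

    # DataCollatorForSeq2Seq pads input_ids to the longest sequence in each
    # batch and pads labels with -100 so the loss ignores the padding.
    collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

    # bf16 where the hardware supports it, otherwise fall back to fp16.
    args = TrainingArguments(
        output_dir="outputs",  # hypothetical path
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
    )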
2 changes: 1 addition & 1 deletion test/aikitfile-unsloth.yaml
@@ -1,6 +1,6 @@
 #syntax=aikit:test
 apiVersion: v1alpha1
-baseModel: unsloth/llama-3-8b-bnb-4bit
+baseModel: unsloth/Meta-Llama-3.1-8B
 datasets:
   - source: "yahma/alpaca-cleaned"
     type: alpaca
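Note: the test config now points at the full Meta-Llama-3.1-8B repository rather than the pre-quantized llama-3-8b-bnb-4bit checkpoint. A sketch of how such a base model is typically loaded with unsloth; max_seq_length and load_in_4bit here are assumed values for illustration:

    from unsloth import FastLanguageModel

    # With a non-pre-quantized checkpoint, 4-bit quantization (if wanted)
    # happens at load time via load_in_4bit rather than being baked into
    # the weights as in the old -bnb-4bit model.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Meta-Llama-3.1-8B",
        max_seq_length=2048,  # assumed for this sketch
        load_in_4bit=True,    # assumed; mirrors the old checkpoint's footprint
    )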
