diff --git a/pkg/finetune/target_unsloth.py b/pkg/finetune/target_unsloth.py
index e010d773..e466bb58 100644
--- a/pkg/finetune/target_unsloth.py
+++ b/pkg/finetune/target_unsloth.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3
 
+from unsloth import is_bfloat16_supported
+from transformers import TrainingArguments, DataCollatorForSeq2Seq
 from unsloth import FastLanguageModel
 import torch
 from trl import SFTTrainer
@@ -71,13 +73,14 @@ def formatting_prompts_func(examples):
 else:
     dataset = load_dataset(source, split = "train")
 
-dataset = dataset.map(formatting_prompts_func, batched = True)
+dataset = dataset.map(formatting_prompts_func, batched=True)
 
 trainer = SFTTrainer(
     model=model,
     train_dataset=dataset,
     dataset_text_field="text",
     max_seq_length=max_seq_length,
+    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
     tokenizer=tokenizer,
     dataset_num_proc = 2,
     packing = cfg.get('packing'), # Can make training 5x faster for short sequences.
@@ -87,8 +90,8 @@ def formatting_prompts_func(examples):
         warmup_steps=cfg.get('warmupSteps'),
         max_steps=cfg.get('maxSteps'),
         learning_rate = cfg.get('learningRate'),
-        fp16=not torch.cuda.is_bf16_supported(),
-        bf16=torch.cuda.is_bf16_supported(),
+        fp16=not is_bfloat16_supported(),
+        bf16=is_bfloat16_supported(),
         logging_steps=cfg.get('loggingSteps'),
         optim=cfg.get('optimizer'),
         weight_decay = cfg.get('weightDecay'),
diff --git a/test/aikitfile-unsloth.yaml b/test/aikitfile-unsloth.yaml
index b23c3f18..bd75ab64 100644
--- a/test/aikitfile-unsloth.yaml
+++ b/test/aikitfile-unsloth.yaml
@@ -1,6 +1,6 @@
 #syntax=aikit:test
 apiVersion: v1alpha1
-baseModel: unsloth/llama-3-8b-bnb-4bit
+baseModel: unsloth/Meta-Llama-3.1-8B
 datasets:
   - source: "yahma/alpaca-cleaned"
     type: alpaca