Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat]: actor critic proof of concept #41

Draft
wants to merge 68 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
3ef75bd
prog
and-rewsmith Mar 26, 2024
3a079ed
prog
and-rewsmith Mar 26, 2024
e996fd3
prog
and-rewsmith Mar 26, 2024
804af8a
delete useless
and-rewsmith Mar 26, 2024
ab78433
clean for matt
and-rewsmith Mar 26, 2024
a984ef5
actor/critic working
and-rewsmith Mar 27, 2024
4ef791d
curriculum
and-rewsmith Mar 27, 2024
28c77f5
wandb
and-rewsmith Mar 27, 2024
5fac99d
prog
and-rewsmith Mar 27, 2024
84e73cd
prog
and-rewsmith Mar 27, 2024
9dd94f5
prog
and-rewsmith Mar 27, 2024
64d51a5
prog
and-rewsmith Mar 28, 2024
bf14a64
incremental reward
and-rewsmith Apr 1, 2024
808cf57
todo
and-rewsmith Apr 1, 2024
b0c61e2
golden run
and-rewsmith Apr 2, 2024
bce98f5
if condition
and-rewsmith Apr 2, 2024
88a4f7d
promising
and-rewsmith Apr 2, 2024
c89b99e
rl working
and-rewsmith Apr 18, 2024
f442997
change detection basics
and-rewsmith Apr 18, 2024
40f7197
todo change detection
and-rewsmith Apr 19, 2024
35dbe83
actor and critic
and-rewsmith Apr 19, 2024
4d1a655
working but slow
and-rewsmith Apr 19, 2024
08bc3e1
running fast
and-rewsmith Apr 19, 2024
4e2d621
actor critic with pause
and-rewsmith Apr 21, 2024
5b3430e
change detection framework acc
and-rewsmith Apr 23, 2024
20d6dc6
batch change detection sort of working
and-rewsmith Apr 23, 2024
8eb4e38
working batch env
and-rewsmith Apr 23, 2024
b8778fc
fix bugs in batch
and-rewsmith Apr 23, 2024
8975f92
broken: fixing device
and-rewsmith Apr 23, 2024
3e8d6dd
floats
and-rewsmith Apr 23, 2024
b66cb99
optimize perf
and-rewsmith Apr 23, 2024
e68d836
tmp run cprofile
and-rewsmith Apr 24, 2024
073b49e
ignore
and-rewsmith May 1, 2024
8ef656a
imitation learning
and-rewsmith May 2, 2024
26b227a
in progress image decode search
and-rewsmith May 2, 2024
0d918a5
predict size issues
and-rewsmith May 2, 2024
a92e789
working search
and-rewsmith May 2, 2024
dbd359d
working search
and-rewsmith May 3, 2024
284ac3d
improve search
and-rewsmith May 3, 2024
9c70cbe
one hot encode inputs
and-rewsmith May 3, 2024
134d721
merge conflict upstream main
and-rewsmith May 3, 2024
0401a8e
no grad for computations layer
and-rewsmith May 3, 2024
ac20e16
prof improvements lif
and-rewsmith May 3, 2024
016ced9
optimize filters
and-rewsmith May 3, 2024
d886f9b
perf fix bug
and-rewsmith May 3, 2024
f259ea4
search data ratios
and-rewsmith May 3, 2024
7448aaa
sweep parallel
and-rewsmith May 3, 2024
a882140
search ironed out
and-rewsmith May 3, 2024
9fb79a4
skip profiling
and-rewsmith May 3, 2024
047219e
constant decoder training
and-rewsmith May 4, 2024
88f0054
sliding window predictions
and-rewsmith May 4, 2024
a278e53
matt's adjustments, however need to tune some things
and-rewsmith Jun 4, 2024
885e5ba
old commenting out idk
and-rewsmith Jun 4, 2024
c151225
update hyperparameters
and-rewsmith Jun 4, 2024
cacef6a
progress
and-rewsmith Jun 4, 2024
dc3f89d
sweep config
and-rewsmith Jun 6, 2024
b9469f1
ff vanilla
and-rewsmith Jun 6, 2024
bae7193
recurrent inputs
and-rewsmith Jun 6, 2024
830214d
recurrent connectivity
and-rewsmith Jun 6, 2024
19eba63
swap RNN
and-rewsmith Jun 6, 2024
d3c9e0b
progress
and-rewsmith Jun 7, 2024
2c27090
progress
and-rewsmith Jun 7, 2024
b1717c4
progress
and-rewsmith Jun 7, 2024
b816ad4
change seed
and-rewsmith Jun 7, 2024
848869e
prog
and-rewsmith Jun 7, 2024
a8ad8e8
tmp
and-rewsmith Jun 9, 2024
267fd1a
tmp
and-rewsmith Jun 9, 2024
e2f6d9d
progress
and-rewsmith Jun 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
sliding window predictions
and-rewsmith committed May 4, 2024

Verified

This commit was signed with the committer’s verified signature.
joyeecheung Joyee Cheung
commit 88f0054d9653a1e7d53ff81a21699883c262b102
39 changes: 24 additions & 15 deletions benchmarks/src/image_decode_search.py
Original file line number Diff line number Diff line change
@@ -31,8 +31,8 @@

# TODOPRE: Think about trade off between high and 1 batch size
BATCH_SIZE = 128
DECODER_EPOCHS_PER_TRIAL = 5
DECODER_LR = 0.001
DECODER_EPOCHS_PER_TRIAL = 7
DECODER_LR = 0.01
DEVICE = "mps"
NUM_SEEDS_BENCH = 1
datetime_str = time.strftime("%Y%m%d-%H%M%S")
@@ -65,7 +65,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.softmax(x, dim=-1)
# return: (batch_size, num_switches, num_classes)

def train(self, internal_state: torch.Tensor, labels: torch.Tensor, image_count: int, num_timesteps_each_image: int, num_epochs: int = DECODER_EPOCHS_PER_TRIAL):
def train(self, internal_state: torch.Tensor, labels: torch.Tensor, image_count: int, num_timesteps_each_image: int, labels_prev: torch.Tensor, num_epochs: int = DECODER_EPOCHS_PER_TRIAL):
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)

@@ -75,19 +75,21 @@ def train(self, internal_state: torch.Tensor, labels: torch.Tensor, image_count:
# outputs: (batch_size, num_switches, num_classes)

num_images_seen = math.ceil(image_count / num_timesteps_each_image)
outputs = outputs[:, :num_images_seen, :]
labels_clone = labels.clone().detach()
labels_clone = labels_clone[:, :num_images_seen]
labels_prev_clone = labels_prev.clone().detach()
labels_prev_clone = labels_prev_clone[:, num_images_seen:]
merged_labels = torch.cat((labels_prev_clone, labels_clone), dim=1)

# labels: (batch_size, num_images_seen)

# Flatten outputs and labels for the loss calculation
outputs = outputs.reshape(-1, self.num_classes)
# outputs: (batch_size * num_switches, num_classes)
labels_clone = labels_clone.reshape(-1)
merged_labels = merged_labels.reshape(-1)
# labels: (batch_size * num_switches,)

loss = criterion(outputs, labels_clone)
loss = criterion(outputs, merged_labels)
wandb.log({"loss": loss.item()})
# loss: scalar

@@ -234,26 +236,33 @@ def bench_specific_seed(running_log: TextIO,
decoder = Decoder(input_size=sum(layer_sizes), num_switches=dataset.num_switches,
num_classes=dataset.num_classes, device=DEVICE, hidden_sizes=[100, 50, 20])

prev_labels = None
batch_count = 0
for batch, labels in tqdm(train_dataloader):
labels_untouched = labels.clone().detach()
# batch: (num_timesteps, batch_size, data_size)
# labels: (batch_size,)
batch = batch.permute(1, 0, 2)
# batch: (batch_size, num_timesteps, data_size)
count = 0
timestep_count = 0
for timestep_data in tqdm(batch, leave=False):
count += 1
timestep_count += 1 # position here matters: predict current image rather than one timestep ago
# timestep_data: (batch_size, data_size)
net.process_data_single_timestep(timestep_data)

layer_activations = net.layer_activations()
# layer_activations: List[torch.Tensor], each tensor of shape (batch_size, layer_size)
internal_state = torch.concat(layer_activations, dim=1)
# internal_state: (batch_size, sum(layer_sizes))
if batch_count >= 1:
layer_activations = net.layer_activations()
# layer_activations: List[torch.Tensor], each tensor of shape (batch_size, layer_size)
internal_state = torch.concat(layer_activations, dim=1)
# internal_state: (batch_size, sum(layer_sizes))

# Reshape labels to (batch_size, num_switches)
labels = labels.view(-1, dataset.num_switches)
# Reshape labels to (batch_size, num_switches)
labels = labels.view(-1, dataset.num_switches)

decoder.train(internal_state, labels, count, dataset.num_timesteps_each_image)
decoder.train(internal_state, labels, timestep_count, dataset.num_timesteps_each_image, prev_labels)

prev_labels = labels_untouched
batch_count += 1

dataset = ImageDataset(
num_timesteps_each_image=20,
2 changes: 1 addition & 1 deletion datasets/src/image_detection/dataset.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ def __init__(self,
num_timesteps_each_image: int,
num_switches: int,
device: str,
max_samples: int = 1024 * 5) -> None:
max_samples: int = 1024 * 2) -> None:
self.num_classes = 10
self.num_timesteps_each_image = num_timesteps_each_image
self.num_switches = num_switches
Loading