-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Address sample_sequence_length
greater than min_length_time_axis
#45
base: main
Are you sure you want to change the base?
Address sample_sequence_length
greater than min_length_time_axis
#45
Conversation
I think the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure why its necessary to replace min_length
with add_batch_size
. Otherwise it makes sense, please just see my comment on the merge conflict.
buffer = flat_buffer.make_flat_buffer( | ||
max_length, min_length, sample_batch_size, False, int(min_length + 10) | ||
max_length, add_batch_size, sample_batch_size, False, add_batch_size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the reasoning for this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_length_time_axis = 3, min_length_time_axis = 1, add_batch_size = 75, sample_sequence_length = 2, period = 1, max_size = None
def validate_trajectory_buffer_args(
max_length_time_axis: Optional[int],
min_length_time_axis: int,
add_batch_size: int,
sample_sequence_length: int,
period: int,
max_size: Optional[int],
) -> None:
"""Validate the arguments of the trajectory buffer."""
validate_size(max_length_time_axis, max_size, add_batch_size)
if max_size is not None:
max_length_time_axis = max_size // add_batch_size
if sample_sequence_length > min_length_time_axis:
> warnings.warn(
"`sample_sequence_length` greater than `min_length_time_axis`, therefore "
"overriding `min_length_time_axis`"
"to be set to `sample_sequence_length`, as we need at least `sample_sequence_length` "
"timesteps added to the buffer before we can sample.",
stacklevel=1,
)
E UserWarning: `sample_sequence_length` greater than `min_length_time_axis`, therefore overriding `min_length_time_axis`to be set to `sample_sequence_length`, as we need at least `sample_sequence_length` timesteps added to the buffer before we can sample.
The warning checks that sample_sequence_length <= min_length_time_axis
. The flat buffer creates a trajectory buffer with sample_sequence_length = 2
and min_length_time_axis = min_length // add_batch_size + 1 = 1 // (1 + 10) + 1 = 0 + 1 = 1
. When calling create trajectory buffer you need to ensure that min_length >= add_batch_size
hence the change. Perhaps it is not the right solution.
@@ -92,7 +92,7 @@ def test_mixed_trajectory_sample( | |||
for i in range(3): | |||
buffer = trajectory_buffer.make_trajectory_buffer( | |||
max_length_time_axis=200 * (i + 1), | |||
min_length_time_axis=0, | |||
min_length_time_axis=sample_sequence_length, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a sensible change.
c485ecf
to
082261d
Compare
Based on #43, please review that first and if it wasn't merged yet, please review this PR per commit.
The goal of this PR is to make the tests not generate flashbax warnings and to prevent code from being added that would trigger a flashbax warning.
I am not sure that the suggested changes in the last commit are right.
I am unsure why
min_length_time_axis=min_length // add_batch_size + 1
is used in the trajectory buffer source code, particularly the+ 1
. I am also unsure why themin_length
fixture is defined as:I could understand if it was defined in terms of
add_batch_size
, but it is not.