Skip to content

Commit

Permalink
graph: interface: fix memory size for u4/s4
Browse files Browse the repository at this point in the history
  • Loading branch information
TaoLv committed Dec 3, 2024
1 parent 6efa272 commit a595116
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/graph/interface/logical_tensor.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2023 Intel Corporation
* Copyright 2020-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,7 +38,9 @@ size_t logical_tensor_wrapper_t::size() const {
static_cast<size_t>(strided_pdim * effective_stride));
}

return max_size * data_type_size();
size_t data_size = utils::div_up(
max_size * data_type_size(), sub_byte_data_type_multiplier());
return data_size;
} else if (is_opaque()) {
size_t layout_id = lt->layout.layout_id;
auto backend
Expand Down
9 changes: 8 additions & 1 deletion src/graph/interface/logical_tensor.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2023 Intel Corporation
* Copyright 2020-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -154,6 +154,13 @@ struct logical_tensor_wrapper_t {
/* check_dtype = */ true);
}

/** Returns the number of elements stored per byte for this tensor's data
 * type: 2 for the sub-byte types s4 and u4, and 1 for every other type. */
size_t sub_byte_data_type_multiplier() const {
    const auto dtype = data_type();
    const bool is_four_bit
            = utils::one_of(dtype, data_type::s4, data_type::u4);
    return is_four_bit ? 2 : 1;
}

// Returns the size of this tensor's data type, as reported by
// types::data_type_size() (presumably in bytes -- confirm against that helper).
size_t data_type_size() const {
    const auto dtype = data_type();
    return types::data_type_size(dtype);
}

Expand Down
6 changes: 6 additions & 0 deletions tests/gtests/graph/api/test_cpp_api_logical_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,4 +311,10 @@ TEST(APILogicalTensor, LogicalTensorSize) {
ASSERT_EQ(lt_3.get_id(), id);
ASSERT_EQ(lt_3.get_data_type(), data_type::s8);
ASSERT_EQ(lt_3.get_mem_size(), num_elem * sizeof(int8_t));

logical_tensor lt_4 {id, data_type::s4, shape, layout_type::strided};
ASSERT_EQ(lt_4.get_id(), id);
ASSERT_EQ(lt_4.get_data_type(), data_type::s4);
// in case num_elem is not even.
ASSERT_EQ(lt_4.get_mem_size(), (num_elem + 1) / 2);
}

0 comments on commit a595116

Please sign in to comment.