Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly encode Duration and optional fields #72

Merged
merged 11 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# matrix-sdk-crypto-wasm v4.0.0

- Properly encode missing and `Duration` parameters in requests.
([#72](https://github.com/matrix-org/matrix-rust-sdk-crypto-wasm/pull/72))

**BREAKING CHANGES**

- Rename `OlmMachine.init_from_store` introduced in v3.6.0 to
Expand Down
31 changes: 31 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ tracing-subscriber = { version = "0.3.14", default-features = false, features =
wasm-bindgen = "0.2.89"
wasm-bindgen-futures = "0.4.33"
zeroize = "1.6.0"
wasm-bindgen-test = "0.3.37"

[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"lint:types": "tsc --noEmit",
"build": "WASM_PACK_ARGS=--release ./scripts/build.sh",
"build:dev": "WASM_PACK_ARGS=--dev ./scripts/build.sh",
"test": "jest --verbose",
"test": "jest --verbose && yarn run wasm-pack test --node",
"doc": "typedoc --tsconfig .",
"prepack": "npm run build && npm run test"
}
Expand Down
116 changes: 109 additions & 7 deletions src/requests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Types to handle requests.

use std::time::Duration;

use js_sys::JsString;
use matrix_sdk_common::ruma::{
api::client::keys::{
Expand All @@ -8,6 +10,7 @@ use matrix_sdk_common::ruma::{
upload_signatures::v3::Request as OriginalSignatureUploadRequest,
},
events::EventContent,
exports::serde::ser::Error,
};
use matrix_sdk_crypto::{
requests::{
Expand Down Expand Up @@ -316,7 +319,7 @@ macro_rules! request {
(
$destination_request:ident from $source_request:ident
$( extracts $( $field_name:ident : $field_type:tt ),+ $(,)? )?
$( $( and )? groups $( $grouped_field_name:ident ),+ $(,)? )?
$( $( and )? groups $( $grouped_field_name:ident $( { $transformation:expr } )? $( $optional:literal )? ),+ $(,)? )?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm adding "optional" to flag things as optional. But the match will match any literal, which is suboptimal. I'd prefer to match just optional, but I don't see any way of doing that. If I use $( optional ), then that would match just optional, but then I don't have a way to expand that back out it the transcription below, because it isn't assigned to a metavariable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we check at runtime that the token is optional? It would be nice to have some kind of check.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or I guess you could have two alternatives, one with optional and one without, and give the metavariables different names? I've not done many of these macros so I'm just guessing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Hywan as original author of the request! macro, do you have any suggestions for this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about simply optional?

$( $( and )? groups $( $grouped_field_name:ident $( { $transformation:expr } )? $( optional )? ),+ $(,)? )?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, I tried doing $( optional )?, but I couldn't figure out how to handle the expansion so that it does something different depending on whether it's specified or not, because it isn't bound to a metavariable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep things as they are now, I may update this later.

) => {

impl TryFrom<(String, &$source_request)> for $destination_request {
Expand All @@ -329,7 +332,7 @@ macro_rules! request {
@__try_from $destination_request from $source_request
(request_id = request_id.into(), request = request)
$( extracts [ $( $field_name : $field_type, )+ ] )?
$( groups [ $( $grouped_field_name, )+ ] )?
$( groups [ $( $grouped_field_name $( { $transformation } )? $( $optional )? , )+ ] )?
)
}
}
Expand All @@ -339,7 +342,7 @@ macro_rules! request {
@__try_from $destination_request:ident from $source_request:ident
(request_id = $request_id:expr, request = $request:expr)
$( extracts [ $( $field_name:ident : $field_type:tt ),* $(,)? ] )?
$( groups [ $( $grouped_field_name:ident ),* $(,)? ] )?
$( groups [ $( $grouped_field_name:ident $( { $transformation:expr } )? $( $optional:literal )? ),* $(,)? ] )?
) => {
{
Ok($destination_request {
Expand All @@ -353,7 +356,15 @@ macro_rules! request {
body: {
let mut map = serde_json::Map::new();
$(
map.insert(stringify!($grouped_field_name).to_owned(), serde_json::to_value(&$request.$grouped_field_name).unwrap());
let field = &$request.$grouped_field_name;
$(
let field = {
let $grouped_field_name = field;

$transformation
};
)?
request!(@__set_field $( $optional )? map : $grouped_field_name = field);
)*
let object = serde_json::Value::Object(map);

Expand All @@ -379,15 +390,25 @@ macro_rules! request {
( @__field_type as event_type ; request = $request:expr, field_name = $field_name:ident ) => {
$request.content.event_type().to_string().into()
};

( @__set_field $optional:literal $map:ident : $grouped_field_name:ident = $field:ident) => {
if let Some($field) = $field {
request!(@__set_field $map : $grouped_field_name = $field);
}
};

( @__set_field $map:ident : $grouped_field_name:ident = $field:ident) => {
$map.insert(stringify!($grouped_field_name).to_owned(), serde_json::to_value($field).unwrap());
};
}

// Generate the methods needed to convert rust `OutgoingRequests` into the js
// counterpart. Technically it's converting tuples `(String, &Original)`, where
// the first member is the request ID, into js requests. Used by
// `TryFrom<OutgoingRequest> for JsValue`.
request!(KeysUploadRequest from OriginalKeysUploadRequest groups device_keys, one_time_keys, fallback_keys);
request!(KeysQueryRequest from OriginalKeysQueryRequest groups timeout, device_keys);
request!(KeysClaimRequest from OriginalKeysClaimRequest groups timeout, one_time_keys);
request!(KeysUploadRequest from OriginalKeysUploadRequest groups device_keys "optional", one_time_keys, fallback_keys);
request!(KeysQueryRequest from OriginalKeysQueryRequest groups timeout { timeout.as_ref().map(Duration::as_millis).map(u64::try_from).transpose().map_err(serde_json::Error::custom)? } "optional", device_keys);
request!(KeysClaimRequest from OriginalKeysClaimRequest groups timeout { timeout.as_ref().map(Duration::as_millis).map(u64::try_from).transpose().map_err(serde_json::Error::custom)? } "optional", one_time_keys);
request!(ToDeviceRequest from OriginalToDeviceRequest extracts event_type: string, txn_id: string and groups messages);
request!(RoomMessageRequest from OriginalRoomMessageRequest extracts room_id: string, txn_id: string, event_type: event_type, content: json);
request!(KeysBackupRequest from OriginalKeysBackupRequest extracts version: string and groups rooms);
Expand Down Expand Up @@ -619,3 +640,84 @@ impl TryFrom<OriginalCrossSigningBootstrapRequests> for CrossSigningBootstrapReq
})
}
}

#[cfg(test)]
pub(crate) mod tests {
use std::collections::BTreeMap;

use matrix_sdk_common::ruma::{
api::client::keys::{
claim_keys::v3::Request as OriginalKeysClaimRequest,
upload_keys::v3::Request as OriginalKeysUploadRequest,
},
device_id, user_id, DeviceKeyAlgorithm,
};
use matrix_sdk_crypto::requests::KeysQueryRequest as OriginalKeysQueryRequest;
use serde_json::Value;
use wasm_bindgen_test::wasm_bindgen_test;

use super::{KeysClaimRequest, KeysQueryRequest, KeysUploadRequest};

#[wasm_bindgen_test]
// make sure that the timeout in a /keys/claim request is encoded as a number
fn test_keys_claim_request_with_timeout() {
let rust_request = OriginalKeysClaimRequest::new(BTreeMap::from([(
user_id!("@alice:localhost").to_owned(),
BTreeMap::from([(
device_id!("ABCDEFG").to_owned(),
DeviceKeyAlgorithm::SignedCurve25519,
)]),
)]));
let request = KeysClaimRequest::try_from(("ID".to_string(), &rust_request)).unwrap();
let body: Value = serde_json::from_str(&String::from(request.body)).unwrap();
assert!(body.as_object().unwrap().contains_key("timeout"));
assert!(body["timeout"].is_number());
}

#[wasm_bindgen_test]
// if a /keys/claim request has no timeout, make sure it isn't in the request
fn test_keys_claim_request_without_timeout() {
let mut rust_request = OriginalKeysClaimRequest::new(BTreeMap::from([(
user_id!("@alice:localhost").to_owned(),
BTreeMap::from([(
device_id!("ABCDEFG").to_owned(),
DeviceKeyAlgorithm::SignedCurve25519,
)]),
)]));
rust_request.timeout = None;
let request = KeysClaimRequest::try_from(("ID".to_string(), &rust_request)).unwrap();
let body: Value = serde_json::from_str(&String::from(request.body)).unwrap();
assert!(!body.as_object().unwrap().contains_key("timeout"));
}

#[wasm_bindgen_test]
// make sure that the timeout is encoded as a number in a /keys/query
fn test_keys_query_request_with_timeout() {
let rust_request = OriginalKeysQueryRequest {
timeout: Some(std::time::Duration::from_secs(10)),
device_keys: BTreeMap::new(),
};
let request = KeysQueryRequest::try_from(("ID".to_string(), &rust_request)).unwrap();
let body: Value = serde_json::from_str(&String::from(request.body)).unwrap();
assert!(body.as_object().unwrap().contains_key("timeout"));
assert!(body["timeout"].is_number());
}

#[wasm_bindgen_test]
// if a /keys/query request has no timeout, make sure it isn't in the request
fn test_keys_query_request_without_timeout() {
let rust_request = OriginalKeysQueryRequest { timeout: None, device_keys: BTreeMap::new() };
let request = KeysQueryRequest::try_from(("ID".to_string(), &rust_request)).unwrap();
let body: Value = serde_json::from_str(&String::from(request.body)).unwrap();
assert!(!body.as_object().unwrap().contains_key("timeout"));
}

#[wasm_bindgen_test]
// if a /keys/upload request no device_keys, make sure it isn't in the request
fn test_keys_upload_request_without_devices() {
let request = OriginalKeysUploadRequest::new();
let request = KeysUploadRequest::try_from(("ID".to_string(), &request)).unwrap();
let body: Value = serde_json::from_str(&String::from(request.body)).unwrap();
assert!(!body.as_object().unwrap().contains_key("device_keys"));
}
}
3 changes: 2 additions & 1 deletion tests/machine.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ describe(OlmMachine.name, () => {
expect(outgoingRequests[1].body).toBeDefined();

const body = JSON.parse(outgoingRequests[1].body);
expect(body.timeout).toBeDefined();
// default timeout in Rust is None, so timeout will be omitted
expect(body.timeout).not.toBeDefined();
expect(body.device_keys).toBeDefined();
}
});
Expand Down