From d21930bd73d6350fbdbb474d07e90baaeb78bfb8 Mon Sep 17 00:00:00 2001 From: Benjamin Lefaudeux Date: Mon, 25 Nov 2024 22:22:46 +0100 Subject: [PATCH] should be good to go, included unit test --- pkg/serdes.go | 7 +++++- tests/client_db_test.go | 50 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/pkg/serdes.go b/pkg/serdes.go index 44e40c4..b80392e 100644 --- a/pkg/serdes.go +++ b/pkg/serdes.go @@ -264,7 +264,12 @@ func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result } func getHTTPRequest(api_url string, api_key string, request dbRequest) *http.Request { - request_url, _ := http.NewRequest("GET", api_url+"images/", nil) + if request.randomSampling { + api_url += "images/random/" + } else { + api_url += "images/" + } + request_url, _ := http.NewRequest("GET", api_url, nil) request_url.Header.Add("Authorization", "Token "+api_key) req := request_url.URL.Query() diff --git a/tests/client_db_test.go b/tests/client_db_test.go index e36efb2..154be50 100644 --- a/tests/client_db_test.go +++ b/tests/client_db_test.go @@ -20,7 +20,6 @@ func get_default_test_config() datago.DatagoConfig { db_config := datago.GetSourceDBConfig() db_config.Sources = get_test_source() db_config.PageSize = 32 - config.SourceConfig = db_config return config } @@ -379,3 +378,52 @@ func TestMultipleSources(t *testing.T) { } client.Stop() } + +func TestRandomSampling(t *testing.T) { + clientConfig := get_default_test_config() + clientConfig.SamplesBufferSize = 1 + dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig) + dbConfig.RandomSampling = true + clientConfig.SourceConfig = dbConfig + + // Fill in two sets with some results + sample_set_1 := make(map[string]interface{}) + sample_set_2 := make(map[string]interface{}) + + { + client := datago.GetClient(clientConfig) + + for i := 0; i < 10; i++ { + sample := client.GetSample() + sample_set_1[sample.ID] = nil + } + client.Stop() + } + + { + client := datago.GetClient(clientConfig) + + for i := 0; i < 10; i++ { + sample := client.GetSample() + sample_set_2[sample.ID] = nil + } + client.Stop() + } + + // Check that the two sets are different + setsAreEqual := func(set1, set2 map[string]interface{}) bool { + if len(set1) != len(set2) { + return false + } + for k := range set1 { + if _, exists := set2[k]; !exists { + return false + } + } + return true + } + + if setsAreEqual(sample_set_1, sample_set_2) { + t.Error("Random sampling is not working") + } +}