Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remote search by tags #6708

Merged
merged 1 commit into from
Jan 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/tribler-common/tribler_common/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_extract_tags():
assert extract_tags('####') == (set(), '####')

assert extract_tags('#tag') == ({'tag'}, '')
assert extract_tags('#Tag') == ({'tag'}, '')
assert extract_tags('a #tag in the middle') == ({'tag'}, 'a in the middle')
assert extract_tags('at the end of the query #tag') == ({'tag'}, 'at the end of the query ')
assert extract_tags('multiple tags: #tag1 #tag2#tag3') == ({'tag1', 'tag2', 'tag3'}, 'multiple tags: ')
Expand Down
3 changes: 2 additions & 1 deletion src/tribler-common/tribler_common/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def extract_tags(text: str) -> Tuple[Set[str], str]:
positions = [0]

for m in tags_re.finditer(text):
tags.add(m.group(0)[1:])
tag = m.group(0)[1:]
tags.add(tag.lower())
positions.extend(itertools.chain.from_iterable(m.regs))
positions.append(len(text))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tribler_core.components.ipv8.ipv8_component import INFINITE, Ipv8Component
from tribler_core.components.metadata_store.metadata_store_component import MetadataStoreComponent
from tribler_core.components.reporter.reporter_component import ReporterComponent
from tribler_core.components.tag.tag_component import TagComponent


class GigaChannelComponent(Component):
Expand All @@ -24,6 +25,7 @@ async def run(self):

self._ipv8_component = await self.require_component(Ipv8Component)
metadata_store_component = await self.require_component(MetadataStoreComponent)
tag_component = await self.get_component(TagComponent)

giga_channel_cls = GigaChannelTestnetCommunity if config.general.testnet else GigaChannelCommunity
community = giga_channel_cls(
Expand All @@ -35,6 +37,7 @@ async def run(self):
rqc_settings=config.remote_query_community,
metadata_store=metadata_store_component.mds,
max_peers=50,
tags_db=tag_component.tags_db if tag_component else None
)
self.community = community
self._ipv8_component.initialise_community_by_default(community, default_random_walk_max_peers=30)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import struct
from asyncio import Future
from binascii import unhexlify
from typing import List, Optional, Set

from ipv8.lazy_community import lazy_wrapper
from ipv8.messaging.lazy_payload import VariablePayload, vp_compile
from ipv8.requestcache import NumberCache, RandomNumberCache, RequestCache

from pony.orm import db_session
from pony.orm.dbapiprovider import OperationalError

from tribler_core.components.ipv8.tribler_community import TriblerCommunity
Expand All @@ -17,6 +19,7 @@
from tribler_core.components.metadata_store.remote_query_community.payload_checker import ObjState
from tribler_core.components.metadata_store.remote_query_community.settings import RemoteQueryCommunitySettings
from tribler_core.components.metadata_store.utils import RequestTimeoutException
from tribler_core.components.tag.community.tag_validator import is_valid_tag
from tribler_core.utilities.unicode import hexlify

BINARY_FIELDS = ("infohash", "channel_pk")
Expand Down Expand Up @@ -129,12 +132,13 @@ class RemoteQueryCommunity(TriblerCommunity, EVAProtocolMixin):
def __init__(self, my_peer, endpoint, network,
rqc_settings: RemoteQueryCommunitySettings = None,
metadata_store=None,
tags_db=None,
**kwargs):
super().__init__(my_peer, endpoint, network=network, **kwargs)

self.rqc_settings = rqc_settings
self.mds: MetadataStore = metadata_store

self.tags_db = tags_db
# This object stores requests for "select" queries that we sent to other hosts.
# We keep track of peers we actually requested for data so people can't randomly push spam at us.
# Also, this keeps track of hosts we responded to. There is a possibility that
Expand Down Expand Up @@ -188,8 +192,23 @@ async def process_rpc_query(self, json_bytes: bytes):
:raises ValueError: if no JSON could be decoded.
:raises pony.orm.dbapiprovider.OperationalError: if an illegal query was performed.
"""
request_sanitized = sanitize_query(json.loads(json_bytes), self.rqc_settings.max_response_size)
return await self.mds.get_entries_threaded(**request_sanitized)
parameters = json.loads(json_bytes)
sanitized_parameters = sanitize_query(parameters, self.rqc_settings.max_response_size)

# tags should be extracted because `get_entries_threaded` doesn't expect them as a parameter
tags = sanitized_parameters.pop('tags', None)

infohash_set = await self.mds.run_threaded(self.search_for_tags, tags)
sanitized_parameters['infohash_set'] = infohash_set # it could be None, it is expected

return await self.mds.get_entries_threaded(**sanitized_parameters)

@db_session
def search_for_tags(self, tags: Optional[List[str]]) -> Optional[Set[bytes]]:
    """Look up the infohashes that match the given tags.

    :param tags: tag names taken from a query; may be None or empty.
    :return: None when there are no tags to search for or when this community
        has no tags database; otherwise the result of `tags_db.get_infohashes`
        for the subset of `tags` accepted by `is_valid_tag`.
    """
    if tags and self.tags_db:
        # Only tags that pass validation are handed to the database.
        accepted_tags = set(filter(is_valid_tag, tags))
        return self.tags_db.get_infohashes(accepted_tags)
    return None

def send_db_results(self, peer, request_payload_id, db_results, force_eva_response=False):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def setUp(self):
self.count = 0
self.metadata_store_set = set()
self.initialize(BasicRemoteQueryCommunity, 2)
self.torrent_template = {"title": "", "infohash": b"", "torrent_date": datetime(1970, 1, 1), "tags": "video"}

async def tearDown(self):
for metadata_store in self.metadata_store_set:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from json import dumps
from unittest.mock import AsyncMock, Mock, PropertyMock, patch

from ipv8.keyvault.crypto import default_eccrypto
from ipv8.test.base import TestBase

from pony.orm import db_session

from tribler_core.components.metadata_store.db.orm_bindings.channel_node import NEW
from tribler_core.components.metadata_store.db.store import MetadataStore
from tribler_core.components.metadata_store.remote_query_community.remote_query_community import RemoteQueryCommunity
from tribler_core.components.metadata_store.remote_query_community.settings import RemoteQueryCommunitySettings
from tribler_core.components.metadata_store.remote_query_community.tests.test_remote_query_community import (
BasicRemoteQueryCommunity,
)
from tribler_core.components.tag.db.tag_db import TagDatabase
from tribler_core.components.tag.db.tests.test_tag_db import Tag, TestTagDB
from tribler_core.utilities.path_util import Path


class TestRemoteSearchByTags(TestBase):
    """Tests for remote search by tags in `RemoteQueryCommunity`.

    Only a single node instance is used, as that is sufficient for testing
    remote search by tags.
    """

    def setUp(self):
        super().setUp()
        # Both databases start as None so that tearDown can tell whether
        # create_node() actually created them.
        self.metadata_store = None
        self.tags_db = None
        # NOTE(review): presumably `initialize` builds the node through create_node() — confirm in TestBase.
        self.initialize(BasicRemoteQueryCommunity, 1)

    async def tearDown(self):
        # Shut down only the databases that were actually created.
        if self.metadata_store:
            self.metadata_store.shutdown()
        if self.tags_db:
            self.tags_db.shutdown()

        await super().tearDown()

    def create_node(self, *args, **kwargs):
        """Create the node with a fresh MetadataStore and TagDatabase passed in via kwargs."""
        self.metadata_store = MetadataStore(
            Path(self.temporary_directory()) / "mds.db",
            Path(self.temporary_directory()),
            default_eccrypto.generate_key("curve25519"),
            disable_sync=True,
        )
        self.tags_db = TagDatabase(str(Path(self.temporary_directory()) / "tags.db"))

        kwargs['metadata_store'] = self.metadata_store
        kwargs['tags_db'] = self.tags_db
        kwargs['rqc_settings'] = RemoteQueryCommunitySettings()
        return super().create_node(*args, **kwargs)

    @property
    def rqc(self) -> RemoteQueryCommunity:
        """The community (overlay) of the only node used in these tests."""
        return self.overlay(0)

    # `tags_db` is replaced on the class by a property that always returns None.
    @patch.object(RemoteQueryCommunity, 'tags_db', new=PropertyMock(return_value=None), create=True)
    async def test_search_for_tags_no_db(self):
        # When `tags_db` is missing, `search_for_tags` must return None.
        assert self.rqc.search_for_tags(tags=['tag']) is None

    @patch.object(TagDatabase, 'get_infohashes')
    async def test_search_for_tags_only_valid_tags(self, mocked_get_infohashes: Mock):
        # `search_for_tags` must pass only the valid tags on to the database.
        self.rqc.search_for_tags(tags=['invalid tag', 'valid_tag'])
        mocked_get_infohashes.assert_called_with({'valid_tag'})

    @patch.object(MetadataStore, 'get_entries_threaded', new_callable=AsyncMock)
    async def test_process_rpc_query_no_tags(self, mocked_get_entries_threaded: AsyncMock):
        # When no tags are given, the remote search must behave like a normal remote search.
        parameters = {'first': 0, 'infohash_set': None, 'last': 100}
        json = dumps(parameters).encode('utf-8')

        await self.rqc.process_rpc_query(json)

        # `infohash_set` is expected to be passed through as None.
        expected_parameters = {'infohash_set': None}
        expected_parameters.update(parameters)
        mocked_get_entries_threaded.assert_called_with(**expected_parameters)

    async def test_process_rpc_query_with_tags(self):
        # This is a full test that checks whether search by tags works.
        #
        # The test fills both databases (the tags database and the MDS) with the following data:
        @db_session
        def fill_tags_database():
            TestTagDB.add_operation_set(
                self.rqc.tags_db,
                {
                    b'infohash1': [
                        Tag(name='tag1', count=2),
                    ],
                    b'infohash2': [
                        Tag(name='tag2', count=1),
                    ]
                })

        @db_session
        def fill_mds():
            with db_session:
                def _add(infohash):
                    torrent = {"infohash": infohash, "title": 'title', "tags": "", "size": 1, "status": NEW}
                    self.rqc.mds.TorrentMetadata.from_dict(torrent)

                _add(b'infohash1')
                _add(b'infohash2')
                _add(b'infohash3')

        fill_tags_database()
        fill_mds()

        # Then we send a search query for three tags: 'tag1', 'tag2' and 'tag3'.
        parameters = {'first': 0, 'infohash_set': None, 'last': 100, 'tags': ['tag1', 'tag2', 'tag3']}
        json = dumps(parameters).encode('utf-8')

        with db_session:
            query_results = [r.to_dict() for r in await self.rqc.process_rpc_query(json)]

            # Expected result: only one infohash (b'infohash1') should be returned.
            # NOTE(review): presumably b'infohash2' is filtered out because of its lower tag count —
            # confirm against TagDatabase.get_infohashes.
            result_infohash_list = [r['infohash'] for r in query_results]
            assert result_infohash_list == [b'infohash1']
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ def sanitize_parameters(self, parameters):
)
@querystring_schema(RemoteQueryParameters)
async def create_remote_search_request(self, request):
self._logger.info('Create remote search request')
# Query remote results from the GigaChannel Community.
# Results are returned over the Events endpoint.
try:
sanitized = self.sanitize_parameters(request.query)
except (ValueError, KeyError) as e:
return RESTResponse({"error": f"Error processing request parameters: {e}"}, status=HTTP_BAD_REQUEST)
self._logger.info(f'Parameters: {sanitized}')

request_uuid, peers_list = self.gigachannel_community.send_search_request(**sanitized)
peers_mid_list = [hexlify(p.mid) for p in peers_list]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def search_db():
try:
with db_session:
if tags:
lower_tags = {tag.lower() for tag in tags}
infohash_set = self.tags_db.get_infohashes(lower_tags)
infohash_set = self.tags_db.get_infohashes(set(tags))
sanitized['infohash_set'] = infohash_set

search_results, total, max_rowid = await mds.run_threaded(search_db)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,13 @@ def validate_tag(tag: str):
raise ValueError('Tag should not contain any spaces')


def is_valid_tag(tag: str) -> bool:
    """Tell whether `tag` passes `validate_tag`.

    :param tag: the tag to check.
    :return: True when `validate_tag` accepts the tag, False when it raises ValueError.
    """
    try:
        validate_tag(tag)
        return True
    except ValueError:
        # A validation failure is reported as a boolean instead of an exception.
        return False


def validate_operation(operation: int):
    """Check that `operation` can be converted to a `TagOperationEnum`.

    The converted value is discarded; the call is made only for its validation effect.
    NOTE(review): presumably `TagOperationEnum` is an `enum.Enum`, so an unknown
    value raises ValueError here — confirm against its definition.

    :param operation: the integer operation code to check.
    """
    TagOperationEnum(operation)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from tribler_core.components.tag.community.tag_payload import TagOperationEnum
from tribler_core.components.tag.community.tag_validator import validate_operation, validate_tag
from tribler_core.components.tag.community.tag_validator import is_valid_tag, validate_operation, validate_tag

pytestmark = pytest.mark.asyncio

Expand Down Expand Up @@ -49,3 +49,10 @@ async def test_contains_upper_case_not_latin():
async def test_contain_any_space():
    # A tag that contains a space must be rejected by `validate_tag` with ValueError.
    with pytest.raises(ValueError):
        validate_tag('tag with space')


async def test_is_valid_tag():
    # `is_valid_tag` performs the same check as `validate_tag`, but it returns
    # a `bool` instead of raising ValueError.
    assert is_valid_tag('valid-tag')
    assert not is_valid_tag('invalid tag')
Loading