Skip to content

Commit

Permalink
be a little tricky with TF native
Browse files Browse the repository at this point in the history
we can't switch it on with TF 1.x
  • Loading branch information
dpressel committed Jul 14, 2021
1 parent 203bf13 commit 8563201
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions mead/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,13 @@ def initialize(self, embeddings_index):
def _create_task_specific_reader(self, vecs_set=None):
self._create_vectorizers(vecs_set)
reader_params = self.config_params['reader'] if 'reader' in self.config_params else self.config_params['loader']
reader_params['backend'] = self.backend.name
if self.backend.name == 'tf' and self.backend.params is not None and not bool(self.backend.params.get('prefer_eager', True)):
reader_params['backend'] = None
else:
reader_params['backend'] = self.backend.name
reader_params['clean_fn'] = reader_params.get('clean_fn', self.config_params.get('preproc', {}).get('clean_fn'))
if reader_params['clean_fn'] is not None and self.config_params['dataset'] != 'SST2':
logger.warning('Warning: A reader preprocessing function (%s) is active, it is recommended that all data preprocessing is done outside of baseline to insure data at inference time matches data at training time.', reader_params['clean_fn'])
logger.warning('Warning: A reader preprocessing function (%s) is active, it is recommended that all data preprocessing is done outside of baseline to ensure data at inference time matches data at training time.', reader_params['clean_fn'])
reader_params['mxlen'] = self.vectorizers[self.primary_key].mxlen
if self.config_params['train'].get('gpus', 1) > 1:
reader_params['truncate'] = True
Expand Down Expand Up @@ -398,7 +401,7 @@ def _create_embeddings(self, embeddings_set, vocabs, features):

# Also, if we are in eager mode, we might have to place the embeddings explicitly on the CPU
embeddings_section['cpu_placement'] = bool(embeddings_section.get('cpu_placement', False))
if self.backend.params is not None:
if self.backend.name == 'tf' and self.backend.params is not None:
# If we are in eager mode
if bool(self.backend.params.get('prefer_eager', False)):
train_block = self.config_params['train']
Expand Down Expand Up @@ -964,9 +967,14 @@ def _create_task_specific_reader(self, vecs_set=None):
reader_params['nctx'] = reader_params.get('nctx', reader_params.get('nbptt', self.config_params.get('nctx', self.config_params.get('nbptt', 35))))
reader_params['clean_fn'] = reader_params.get('clean_fn', self.config_params.get('preproc', {}).get('clean_fn'))
if reader_params['clean_fn'] is not None and self.config_params['dataset'] != 'SST2':
logger.warning('Warning: A reader preprocessing function (%s) is active, it is recommended that all data preprocessing is done outside of baseline to insure data at inference time matches data at training time.', reader_params['clean_fn'])
logger.warning('Warning: A reader preprocessing function (%s) is active, it is recommended that all data preprocessing is done outside of baseline to ensure data at inference time matches data at training time.', reader_params['clean_fn'])
reader_params['mxlen'] = self.vectorizers[self.primary_key].mxlen
reader_params['backend'] = self.backend.name

if self.backend.name == 'tf' and self.backend.params is not None and not bool(self.backend.params.get('prefer_eager', True)):
reader_params['backend'] = None
else:
reader_params['backend'] = self.backend.name

if self.config_params['train'].get('gpus', 1) > 1:
reader_params['truncate'] = True
return baseline.reader.create_reader(self.task_name(), self.vectorizers, self.config_params.get('preproc', {}).get('trim', False), **reader_params)
Expand Down

0 comments on commit 8563201

Please sign in to comment.