From f9c0ef9f5adae989b743f0347cf3233eed0899dd Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Fri, 5 May 2017 18:50:20 -0400 Subject: [PATCH 1/4] Update helper.py Hi @dennybritz, I had to make these proposed changes to make the ScheduledOutputSampling work. I was getting `Can't convert bool to float 32` error before adding `maybe_concatenate_auxiliary_inputs(outputs).cell_output` to line 420. I also got `InvalidArgumentError: TensorArray dtype is int32 but Op is trying to write dtype bool.` before making the changes to line 387, 388 and the casting in 420. Please advice if the changes are valid. --- seq2seq/contrib/seq2seq/helper.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/seq2seq/contrib/seq2seq/helper.py b/seq2seq/contrib/seq2seq/helper.py index 977d0ab9..044311f2 100644 --- a/seq2seq/contrib/seq2seq/helper.py +++ b/seq2seq/contrib/seq2seq/helper.py @@ -384,10 +384,8 @@ def initialize(self, name=None): def sample(self, time, outputs, state, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperSample", [time, outputs, state]): - sampler = bernoulli.Bernoulli(probs=self._sampling_probability) - return math_ops.cast( - sampler.sample(sample_shape=self.batch_size, seed=self._seed), - dtypes.bool) + sampler = bernoulli.Bernoulli(probs=self._sampling_probability, dtype=dtypes.int32) + return sampler.sample(sample_shape=self.batch_size, seed=self._seed) def next_inputs(self, time, outputs, state, sample_ids, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs", @@ -419,7 +417,7 @@ def maybe_concatenate_auxiliary_inputs(outputs_, indices=None): if self._next_input_layer is None: return array_ops.where( - sample_ids, maybe_concatenate_auxiliary_inputs(outputs), + math_ops.cast(sample_ids, dtypes.bool), maybe_concatenate_auxiliary_inputs(outputs).cell_output, base_next_inputs) where_sampling = math_ops.cast( From 16f4821dd432fd2ae81f43852b5f113db946db4a Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Sat, 6 May 2017 18:13:11 -0400 Subject: [PATCH 2/4] Update hooks_test.py Adding changes from @graehls pull request 'cause tests were failing. --- seq2seq/test/hooks_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/seq2seq/test/hooks_test.py b/seq2seq/test/hooks_test.py index dedc6594..d7dfcc9a 100644 --- a/seq2seq/test/hooks_test.py +++ b/seq2seq/test/hooks_test.py @@ -39,16 +39,16 @@ class TestPrintModelAnalysisHook(tf.test.TestCase): def test_begin(self): model_dir = tempfile.mkdtemp() outfile = tempfile.NamedTemporaryFile() - tf.get_variable("weigths", [128, 128]) + tf.get_variable("weights", [128, 128]) hook = hooks.PrintModelAnalysisHook( params={}, model_dir=model_dir, run_config=tf.contrib.learn.RunConfig()) hook.begin() with gfile.GFile(os.path.join(model_dir, "model_analysis.txt")) as file: - file_contents = file.read().strip() + file_contents = tf.compat.as_text(file.read()).strip() self.assertEqual(file_contents.decode(), "_TFProfRoot (--/16.38k params)\n" - " weigths (128x128, 16.38k/16.38k params)") + " weights (128x128, 16.38k/16.38k params)") outfile.close() @@ -108,7 +108,7 @@ def test_sampling(self): outfile = os.path.join(self.sample_dir, "samples_000010.txt") with open(outfile, "rb") as readfile: self.assertIn("Prediction followed by Target @ Step 10", - readfile.read().decode("utf-8")) + tf.compat.as_text(readfile.read()).decode("utf-8")) class TestMetadataCaptureHook(tf.test.TestCase): @@ -125,7 +125,7 @@ def tearDown(self): def test_capture(self): global_step = tf.contrib.framework.get_or_create_global_step() # Some test computation - some_weights = tf.get_variable("weigths", [2, 128]) + some_weights = tf.get_variable("weights", [2, 128]) computation = tf.nn.softmax(some_weights) hook = hooks.MetadataCaptureHook( From d9119609feeea1eecc965681276ec8bc62eb1482 Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Sat, 6 May 2017 18:35:18 -0400 Subject: [PATCH 3/4] Update hooks_test.py Adding changes from @graehls pull request 'cause tests were failing. --- seq2seq/test/hooks_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/seq2seq/test/hooks_test.py b/seq2seq/test/hooks_test.py index d7dfcc9a..c49ea884 100644 --- a/seq2seq/test/hooks_test.py +++ b/seq2seq/test/hooks_test.py @@ -94,7 +94,7 @@ def test_sampling(self): outfile = os.path.join(self.sample_dir, "samples_000000.txt") with open(outfile, "rb") as readfile: self.assertIn("Prediction followed by Target @ Step 0", - readfile.read().decode("utf-8")) + tf.compat.as_text(readfile.read())) # Should not trigger for step 9 sess.run(tf.assign(global_step, 9)) @@ -108,7 +108,7 @@ def test_sampling(self): outfile = os.path.join(self.sample_dir, "samples_000010.txt") with open(outfile, "rb") as readfile: self.assertIn("Prediction followed by Target @ Step 10", - tf.compat.as_text(readfile.read()).decode("utf-8")) + tf.compat.as_text(readfile.read())) class TestMetadataCaptureHook(tf.test.TestCase): From aa3cf5ad7c9a1cc74e260024f7e5b5e833bba3f2 Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Sat, 6 May 2017 18:47:56 -0400 Subject: [PATCH 4/4] Update hooks_test.py Adding changes from @graehl's pull request --- seq2seq/test/hooks_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/seq2seq/test/hooks_test.py b/seq2seq/test/hooks_test.py index c49ea884..1537990f 100644 --- a/seq2seq/test/hooks_test.py +++ b/seq2seq/test/hooks_test.py @@ -47,7 +47,7 @@ def test_begin(self): with gfile.GFile(os.path.join(model_dir, "model_analysis.txt")) as file: file_contents = tf.compat.as_text(file.read()).strip() - self.assertEqual(file_contents.decode(), "_TFProfRoot (--/16.38k params)\n" + self.assertEqual(file_contents, "_TFProfRoot (--/16.38k params)\n" " weights (128x128, 16.38k/16.38k params)") outfile.close()