diff --git a/s3file/forms.py b/s3file/forms.py index 0565602..21e7976 100644 --- a/s3file/forms.py +++ b/s3file/forms.py @@ -1,7 +1,9 @@ import base64 +import html import logging import pathlib import uuid +from html.parser import HTMLParser from django.conf import settings from django.templatetags.static import static @@ -16,6 +18,71 @@ logger = logging.getLogger("s3file") +class InputToS3FileRewriter(HTMLParser): + """HTML parser that rewrites to custom elements.""" + + def __init__(self): + super().__init__() + self.output = [] + + def handle_starttag(self, tag, attrs): + if tag == "input" and dict(attrs).get("type") == "file": + self.output.append("") + else: + self.output.append(self.get_starttag_text()) + + def handle_endtag(self, tag): + self.output.append(f"") + + def handle_data(self, data): + self.output.append(data) + + def handle_startendtag(self, tag, attrs): + if tag == "input" and dict(attrs).get("type") == "file": + self.output.append("") + else: + self.output.append(self.get_starttag_text()) + + def handle_comment(self, data): + # Preserve HTML comments in the output + self.output.append(f"") + + def handle_decl(self, decl): + # Preserve declarations such as in the output + self.output.append(f"") + + def handle_pi(self, data): + # Preserve processing instructions such as in the output + self.output.append(f"") + + def handle_entityref(self, name): + # Preserve HTML entities like &, <, > + self.output.append(f"&{name};") + + def handle_charref(self, name): + # Preserve character references like ', ' + self.output.append(f"&#{name};") + + def get_html(self): + return "".join(self.output) + + @html_safe class Asset: """A generic asset that can be included in a template.""" @@ -99,11 +166,10 @@ def build_attrs(self, *args, **kwargs): def render(self, name, value, attrs=None, renderer=None): """Render the widget as a custom element for Safari compatibility.""" - return mark_safe( # noqa: S308 - str(super().render(name, value, attrs=attrs, renderer=renderer)).replace( - f'' +class TestInputToS3FileRewriter: + def test_transforms_file_input(self): + parser = forms.InputToS3FileRewriter() + parser.feed('') + assert parser.get_html() == '' + + def test_preserves_non_file_input(self): + parser = forms.InputToS3FileRewriter() + parser.feed('') + assert parser.get_html() == '' + + def test_handles_attribute_ordering(self): + parser = forms.InputToS3FileRewriter() + parser.feed('') + result = parser.get_html() + assert result.startswith("' + ) + result = parser.get_html() + assert result.startswith("') + result = parser.get_html() + assert 'data-value="test&value"' in result + + def test_preserves_existing_html_entities(self): + # Test that already-escaped entities in input are preserved (not double-escaped) + parser = forms.InputToS3FileRewriter() + parser.feed('') + result = parser.get_html() + # Should preserve the & entity, not convert to &amp; + assert 'data-value="test&value"' in result + assert '&amp;' not in result + + def test_preserves_character_references(self): + # Test that character references are preserved (may be in decimal or hex format) + parser = forms.InputToS3FileRewriter() + parser.feed('') + result = parser.get_html() + # The character reference should be preserved (either ' or ' both represent ') + assert ('data-value="test's"' in result or 'data-value="test's"' in result) + # Verify the actual apostrophe character is NOT directly in the output (should be a reference) + assert 'data-value="test\'s"' not in result or '&#' in result + + def test_handles_self_closing_tag(self): + parser = forms.InputToS3FileRewriter() + parser.feed('') + assert parser.get_html() == '' + + def test_preserves_non_file_self_closing_tag(self): + parser = forms.InputToS3FileRewriter() + parser.feed('') + assert parser.get_html() == '' + + def test_preserves_surrounding_elements(self): + parser = forms.InputToS3FileRewriter() + parser.feed('

') + result = parser.get_html() + assert result == '

' + + def test_preserves_html_comments(self): + parser = forms.InputToS3FileRewriter() + parser.feed('') + result = parser.get_html() + assert result == '' + + def test_preserves_declarations(self): + parser = forms.InputToS3FileRewriter() + parser.feed('') + result = parser.get_html() + assert result == '' + + def test_preserves_processing_instructions(self): + parser = forms.InputToS3FileRewriter() + parser.feed('') + result = parser.get_html() + assert result == '' + + @contextmanager def wait_for_page_load(driver, timeout=30): old_page = driver.find_element(By.TAG_NAME, "html") @@ -186,6 +278,21 @@ def test_render_wraps_in_s3_file_element(self, freeze_upload_folder): # Check that the output is the s3-file custom element assert html.startswith("