diff --git a/s3file/forms.py b/s3file/forms.py
index 0565602..21e7976 100644
--- a/s3file/forms.py
+++ b/s3file/forms.py
@@ -1,7 +1,9 @@
import base64
+import html
import logging
import pathlib
import uuid
+from html.parser import HTMLParser
from django.conf import settings
from django.templatetags.static import static
@@ -16,6 +18,71 @@
logger = logging.getLogger("s3file")
class InputToS3FileRewriter(HTMLParser):
    """Rewrite ``<input type="file">`` tags to ``<s3-file>`` custom elements.

    Feed markup with :meth:`feed` and read the rewritten markup back with
    :meth:`get_html`. Every construct other than a file input (other tags,
    text, comments, declarations, processing instructions and character or
    entity references) is written back out in its source form.

    NOTE(review): the ``s3-file`` tag name and the removal of the ``type``
    attribute follow the widget's documented "custom element" rendering;
    confirm against the element definition shipped with the package.
    """

    def __init__(self):
        # ``convert_charrefs=False`` is required for ``handle_entityref`` and
        # ``handle_charref`` to be called at all. With the default (True) the
        # parser decodes references inside text before ``handle_data`` sees
        # them, so ``&lt;`` would be re-emitted as a raw ``<``.
        super().__init__(convert_charrefs=False)
        self.output = []

    def _rewrite_or_keep(self, tag, attrs):
        """Return the markup to emit for a start tag or self-closing tag."""
        if tag != "input" or dict(attrs).get("type") != "file":
            # Not a file input: keep the tag exactly as it was written.
            return self.get_starttag_text()
        parts = ["<s3-file"]
        for name, value in attrs:
            if name == "type":
                # The custom element replaces the input type entirely.
                continue
            if value is None:
                # Boolean attribute such as ``required`` or ``multiple``.
                parts.append(f" {name}")
            else:
                # The parser hands over attribute values already unescaped,
                # so they must be escaped again before being written out.
                parts.append(f' {name}="{html.escape(value, quote=True)}"')
        parts.append(">")
        return "".join(parts)

    def handle_starttag(self, tag, attrs):
        """Emit a start tag, rewriting it when it is a file input."""
        self.output.append(self._rewrite_or_keep(tag, attrs))

    def handle_endtag(self, tag):
        """Emit a closing tag."""
        self.output.append(f"</{tag}>")

    def handle_data(self, data):
        """Emit text content unchanged."""
        self.output.append(data)

    def handle_startendtag(self, tag, attrs):
        """Emit a self-closing tag, rewriting it when it is a file input."""
        self.output.append(self._rewrite_or_keep(tag, attrs))

    def handle_comment(self, data):
        """Preserve HTML comments in the output."""
        self.output.append(f"<!--{data}-->")

    def handle_decl(self, decl):
        """Preserve declarations such as ``<!DOCTYPE html>`` in the output."""
        self.output.append(f"<!{decl}>")

    def handle_pi(self, data):
        """Preserve processing instructions such as ``<?xml ...?>``."""
        self.output.append(f"<?{data}>")

    def handle_entityref(self, name):
        """Preserve named entities such as ``&amp;``, ``&lt;`` and ``&gt;``."""
        self.output.append(f"&{name};")

    def handle_charref(self, name):
        """Preserve numeric references such as ``&#39;`` and ``&#x27;``."""
        self.output.append(f"&#{name};")

    def get_html(self):
        """Return everything emitted so far as a single string."""
        return "".join(self.output)
+
+
@html_safe
class Asset:
"""A generic asset that can be included in a template."""
@@ -99,11 +166,10 @@ def build_attrs(self, *args, **kwargs):
def render(self, name, value, attrs=None, renderer=None):
"""Render the widget as a custom element for Safari compatibility."""
- return mark_safe( # noqa: S308
- str(super().render(name, value, attrs=attrs, renderer=renderer)).replace(
- f''
+class TestInputToS3FileRewriter:
+ def test_transforms_file_input(self):
+ parser = forms.InputToS3FileRewriter()
+ parser.feed('')
+ assert parser.get_html() == ''
+
+ def test_preserves_non_file_input(self):
+ parser = forms.InputToS3FileRewriter()
+ parser.feed('')
+ assert parser.get_html() == ''
+
+ def test_handles_attribute_ordering(self):
+ parser = forms.InputToS3FileRewriter()
+ parser.feed('')
+ result = parser.get_html()
+ assert result.startswith("'
+ )
+ result = parser.get_html()
+ assert result.startswith("')
+ result = parser.get_html()
+ assert 'data-value="test&value"' in result
+
+ def test_preserves_existing_html_entities(self):
+ # Test that already-escaped entities in input are preserved (not double-escaped)
+ parser = forms.InputToS3FileRewriter()
+ parser.feed('')
+ result = parser.get_html()
+ # Should preserve the & entity, not convert to &
+ assert 'data-value="test&value"' in result
+ assert '&' not in result
+
+ def test_preserves_character_references(self):
+ # Test that character references are preserved (may be in decimal or hex format)
+ parser = forms.InputToS3FileRewriter()
+ parser.feed('')
+ result = parser.get_html()
+ # The character reference should be preserved (either ' or ' both represent ')
+ assert ('data-value="test's"' in result or 'data-value="test's"' in result)
+ # Verify the actual apostrophe character is NOT directly in the output (should be a reference)
+ assert 'data-value="test\'s"' not in result or '' in result
+
+ def test_handles_self_closing_tag(self):
+ parser = forms.InputToS3FileRewriter()
+ parser.feed('')
+ assert parser.get_html() == ''
+
+ def test_preserves_non_file_self_closing_tag(self):
+ parser = forms.InputToS3FileRewriter()
+ parser.feed('')
+ assert parser.get_html() == ''
+
+ def test_preserves_surrounding_elements(self):
+ parser = forms.InputToS3FileRewriter()
+ parser.feed('')
+ result = parser.get_html()
+ assert result == '
'
+
+ def test_preserves_html_comments(self):
+ parser = forms.InputToS3FileRewriter()
+ parser.feed('')
+ result = parser.get_html()
+ assert result == ''
+
+ def test_preserves_declarations(self):
+ parser = forms.InputToS3FileRewriter()
+ parser.feed('')
+ result = parser.get_html()
+ assert result == ''
+
+ def test_preserves_processing_instructions(self):
+ parser = forms.InputToS3FileRewriter()
+ parser.feed('')
+ result = parser.get_html()
+ assert result == ''
+
+
@contextmanager
def wait_for_page_load(driver, timeout=30):
old_page = driver.find_element(By.TAG_NAME, "html")
@@ -186,6 +278,21 @@ def test_render_wraps_in_s3_file_element(self, freeze_upload_folder):
# Check that the output is the s3-file custom element
assert html.startswith("