diff --git a/lib/remote_input/downloader.rb b/lib/remote_input/downloader.rb
index 3381a09..84030d6 100644
--- a/lib/remote_input/downloader.rb
+++ b/lib/remote_input/downloader.rb
@@ -12,26 +12,20 @@ module RemoteInput
   class Downloader
     class TooManyRedirects < Error; end
 
-    def initialize(url)
-      if url.is_a?(URI::Generic)
-        url = url.dup
-      else
-        url = URI.parse(url)
-      end
-      @url = url
-      unless @url.is_a?(URI::HTTP)
-        raise ArgumentError, "download URL must be HTTP or HTTPS: <#{@url}>"
-      end
+    def initialize(url, *fallback_urls, http_method: nil, http_parameters: nil)
+      @url = normalize_url(url)
+      @fallback_urls = fallback_urls.collect { |fallback_url| normalize_url(fallback_url) }
+      @http_method = http_method
+      @http_parameters = http_parameters
     end
 
     def download(output_path, &block)
-      if output_path.exist?
-        yield_chunks(output_path, &block) if block_given?
-        return
-      end
+      return if use_cache(output_path, &block)
 
       partial_output_path = Pathname.new("#{output_path}.partial")
       synchronize(output_path, partial_output_path) do
+        return if use_cache(output_path, &block)
+
         output_path.parent.mkpath
 
         n_retries = 0
@@ -47,7 +41,7 @@ def download(output_path, &block)
             headers["Range"] = "bytes=#{start}-"
           end
 
-          start_http(@url, headers) do |response|
+          start_http(@url, @fallback_urls, headers) do |response|
             if response.is_a?(Net::HTTPPartialContent)
               mode = "ab"
             else
@@ -87,6 +81,27 @@ def download(output_path, &block)
       end
     end
 
+    private def normalize_url(url)
+      if url.is_a?(URI::Generic)
+        url = url.dup
+      else
+        url = URI.parse(url)
+      end
+      unless url.is_a?(URI::HTTP)
+        raise ArgumentError, "download URL must be HTTP or HTTPS: <#{url}>"
+      end
+      url
+    end
+
+    private def use_cache(output_path, &block)
+      if output_path.exist?
+        yield_chunks(output_path, &block) if block_given?
+        true
+      else
+        false
+      end
+    end
+
     private def synchronize(output_path, partial_output_path)
       begin
         Process.getpgid(Process.pid)
@@ -106,7 +121,8 @@ def download(output_path, &block)
           rescue ArgumentError
             # The process that acquired the lock will be exited before
             # it stores its process ID.
-            valid_lock_path = (lock_path.mtime > 10)
+            elapsed_time = Time.now - lock_path.mtime
+            valid_lock_path = (elapsed_time < 10)
           else
             begin
               Process.getpgid(pid)
@@ -135,7 +151,7 @@ def download(output_path, &block)
       end
     end
 
-    private def start_http(url, headers, limit = 10, &block)
+    private def start_http(url, fallback_urls, headers, limit = 10, &block)
       if limit == 0
         raise TooManyRedirects, "too many redirections: #{url}"
       end
@@ -145,7 +161,30 @@ def download(output_path, &block)
       http.start do
         path = url.path
         path += "?#{url.query}" if url.query
-        request = Net::HTTP::Get.new(path, headers)
+        if @http_method == :post
+          # TODO: We may want to add @http_content_type, @http_body
+          # and so on.
+          if @http_parameters
+            body = URI.encode_www_form(@http_parameters)
+            content_type = "application/x-www-form-urlencoded"
+            headers = {"Content-Type" => content_type}.merge(headers)
+          else
+            body = ""
+          end
+          request = Net::HTTP::Post.new(path, headers)
+          request.body = body
+        else
+          request = Net::HTTP::Get.new(path, headers)
+        end
+        if url.scheme == "https" and url.host == "api.github.com"
+          gh_token = ENV["GH_TOKEN"]
+          if gh_token and not gh_token.empty?
+            # Set the token on the request itself: the request was built
+            # from headers above, so merging into headers here would never
+            # reach it and would only leak it to redirect/fallback hosts.
+            request["Authorization"] = "Bearer #{gh_token}"
+          end
+        end
         http.request(request) do |response|
           case response
           when Net::HTTPSuccess, Net::HTTPPartialContent
@@ -153,8 +192,19 @@ def download(output_path, &block)
           when Net::HTTPRedirection
             url = URI.parse(response[:location])
             $stderr.puts "Redirect to #{url}"
-            return start_http(url, headers, limit - 1, &block)
+            return start_http(url, fallback_urls, headers, limit - 1, &block)
           else
+            case response
+            when Net::HTTPForbidden, Net::HTTPNotFound
+              next_url, *rest_fallback_urls = fallback_urls
+              if next_url
+                message = "#{response.code}: #{response.message}: " +
+                          "fallback: <#{url}> -> <#{next_url}>"
+                $stderr.puts(message)
+                return start_http(next_url, rest_fallback_urls, headers, &block)
+              end
+            end
+
             message = response.code
             if response.message and not response.message.empty?
               message += ": #{response.message}"
@@ -166,11 +216,12 @@ def download(output_path, &block)
       end
     end
 
-    private def yield_chunks(path)
-      path.open("rb") do |output|
+    private def yield_chunks(path, &block)
+      return unless block_given?
+
+      path.open("rb") do |input|
         chunk_size = 1024 * 1024
-        chunk = ""
-        while output.read(chunk_size, chunk)
+        while chunk = input.read(chunk_size)
           yield(chunk)
         end
       end
diff --git a/test/test-downloader.rb b/test/test-downloader.rb
index 2f01f73..7048736 100644
--- a/test/test-downloader.rb
+++ b/test/test-downloader.rb
@@ -1,6 +1,43 @@
 class DownloaderTest < Test::Unit::TestCase
   include Helper::Sandbox
 
+  sub_test_case("#initialize") do
+    test("single URL") do
+      url = "https://example.com/file"
+      downloader = RemoteInput::Downloader.new(url)
+      assert_equal(URI.parse(url), downloader.instance_variable_get(:@url))
+      assert_equal([], downloader.instance_variable_get(:@fallback_urls))
+    end
+
+    test("with fallback URLs") do
+      url = "https://example.com/file"
+      fallback1 = "https://mirror1.example.com/file"
+      fallback2 = "https://mirror2.example.com/file"
+      downloader = RemoteInput::Downloader.new(url, fallback1, fallback2)
+
+      assert_equal(URI.parse(url), downloader.instance_variable_get(:@url))
+      assert_equal([URI.parse(fallback1), URI.parse(fallback2)],
+                   downloader.instance_variable_get(:@fallback_urls))
+    end
+
+    test("with HTTP method and parameters") do
+      url = "https://example.com/api"
+      parameters = { key: "value", data: "test" }
+      downloader = RemoteInput::Downloader.new(url,
+                                               http_method: :post,
+                                               http_parameters: parameters)
+
+      assert_equal(:post, downloader.instance_variable_get(:@http_method))
+      assert_equal(parameters, downloader.instance_variable_get(:@http_parameters))
+    end
+
+    test("invalid URL") do
+      assert_raise(ArgumentError) do
+        RemoteInput::Downloader.new("ftp://example.com/file")
+      end
+    end
+  end
+
   sub_test_case("#download") do
     def setup
       setup_sandbox
@@ -17,7 +54,7 @@ def teardown
       output_path = @tmp_dir + "file"
       downloader = RemoteInput::Downloader.new(first_url)
 
-      downloader.define_singleton_method(:start_http) do |url, headers|
+      downloader.define_singleton_method(:start_http) do |url, fallback_urls, headers|
         raise RemoteInput::Downloader::TooManyRedirects, "too many redirections: #{last_url}"
       end
 
@@ -25,5 +62,55 @@ def teardown
         downloader.download(output_path)
       end
     end
+
+    test("use cache when file exists") do
+      output_path = @tmp_dir + "cached_file"
+      output_path.write("cached content")
+
+      downloader = RemoteInput::Downloader.new("https://example.com/file")
+
+      # Should not call start_http when file exists
+      downloader.define_singleton_method(:start_http) do |url, fallback_urls, headers|
+        flunk("start_http should not be called when file exists")
+      end
+
+      downloader.download(output_path)
+      assert_equal("cached content", output_path.read)
+    end
+
+    test("yield chunks when using cache") do
+      output_path = @tmp_dir + "cached_file"
+      content = "chunk1chunk2chunk3"
+      output_path.write(content)
+
+      downloader = RemoteInput::Downloader.new("https://example.com/file")
+
+      chunks = []
+      downloader.download(output_path) do |chunk|
+        chunks << chunk
+      end
+
+      assert_equal(content, chunks.join)
+    end
+  end
+
+  sub_test_case("fallback URLs") do
+    def setup
+      setup_sandbox
+    end
+
+    def teardown
+      teardown_sandbox
+    end
+
+    test("fallback URLs are stored correctly") do
+      main_url = "https://example.com/file"
+      fallback_url = "https://mirror.example.com/file"
+
+      downloader = RemoteInput::Downloader.new(main_url, fallback_url)
+
+      fallback_urls = downloader.instance_variable_get(:@fallback_urls)
+      assert_equal([URI.parse(fallback_url)], fallback_urls)
+    end
   end
 end