#include "modbus.h"
#include "esphome/core/application.h"
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"

namespace esphome {
namespace modbus {

static const char *const TAG = "modbus";

// Maximum bytes to log for Modbus frames (truncated if larger)
static constexpr size_t MODBUS_MAX_LOG_BYTES = 64;

void Modbus::setup() {
  // The flow control pin is optional (send() drives it high while transmitting);
  // there is nothing to initialise when none was configured.
  if (this->flow_control_pin_ == nullptr)
    return;
  this->flow_control_pin_->setup();
}
/// Drain every byte the UART currently has buffered, feeding each one to parse_modbus_byte_(), then apply
/// the inter-frame idle timeout.
///
/// After RX_IDLE_TIMEOUT_MS without an accepted byte, any partially received frame is discarded.
/// Once send_wait_time_ has also elapsed since the last send, waiting_for_response is cleared so new
/// commands are no longer blocked.
void Modbus::loop() {
  // Silence (ms) after the last accepted byte before a partial frame is dropped.
  static constexpr uint32_t RX_IDLE_TIMEOUT_MS = 50;
  const uint32_t now = App.get_loop_component_start_time();

  // Read all available bytes in batches to reduce UART call overhead.
  size_t avail = this->available();
  uint8_t buf[64];
  while (avail > 0) {
    size_t to_read = std::min(avail, sizeof(buf));
    if (!this->read_array(buf, to_read)) {
      break;
    }
    avail -= to_read;

    for (size_t i = 0; i < to_read; i++) {
      if (this->parse_modbus_byte_(buf[i])) {
        this->last_modbus_byte_ = now;
      } else {
        // Parser rejected the frame (CRC failure): drop everything buffered so far.
        size_t at = this->rx_buffer_.size();
        if (at > 0) {
          // %zu: `at` is a size_t; %d would be a format/argument mismatch.
          ESP_LOGV(TAG, "Clearing buffer of %zu bytes - parse failed", at);
          this->rx_buffer_.clear();
        }
      }
    }
  }

  if (now - this->last_modbus_byte_ > RX_IDLE_TIMEOUT_MS) {
    size_t at = this->rx_buffer_.size();
    if (at > 0) {
      ESP_LOGV(TAG, "Clearing buffer of %zu bytes - timeout", at);
      this->rx_buffer_.clear();
    }

    // stop blocking new send commands after send_wait_time_ ms after response received
    if (now - this->last_send_ > send_wait_time_) {
      if (waiting_for_response > 0) {
        ESP_LOGV(TAG, "Stop waiting for response from %d", waiting_for_response);
      }
      waiting_for_response = 0;
    }
  }
}

/// Append one received byte to rx_buffer_ and, once a complete frame is buffered, validate and dispatch it.
///
/// Frame length depends on the function code and on this bus's role:
/// - User-defined function codes are treated as complete as soon as the last two bytes match the CRC of
///   everything before them.
/// - Standard codes use a fixed or byte-count-derived data length, followed by a two-byte CRC (low byte first).
///
/// A complete frame's data section goes to every registered device whose address matches byte 0. The buffer
/// is then cleared and waiting_for_response is reset.
///
/// @param byte The byte just read from the UART.
/// @return false when a standard frame fails its CRC check (and CRC checking is enabled), so the caller
///         discards the buffer; true otherwise (frame still incomplete, or frame consumed).
bool Modbus::parse_modbus_byte_(uint8_t byte) {
  // `at` is the index of the byte being appended, i.e. the buffer length before the push.
  size_t at = this->rx_buffer_.size();
  this->rx_buffer_.push_back(byte);
  const uint8_t *raw = &this->rx_buffer_[0];
  ESP_LOGVV(TAG, "Modbus received Byte  %d (0X%x)", byte, byte);
  // Byte 0: modbus address (match all), Byte 1: function code,
  // Byte 2: Size (with modbus rtu function code 4/3).
  // See also https://en.wikipedia.org/wiki/Modbus
  // No frame can be complete with fewer than 4 bytes (address, function code, 2-byte CRC).
  // raw[2] must not be read before it has been received, so wait here for at <= 2.
  if (at < 3)
    return true;
  uint8_t address = raw[0];
  uint8_t function_code = raw[1];

  // size_t rather than uint8_t: user-defined frames may exceed 255 bytes, and the write-multiple
  // length (5 + byte count) can exceed 255. Either would silently wrap in a uint8_t.
  size_t data_len = raw[2];
  uint8_t data_offset = 3;

  // Per https://modbus.org/docs/Modbus_Application_Protocol_V1_1b3.pdf Ch 5 User-Defined function codes
  if (((function_code >= FUNCTION_CODE_USER_DEFINED_SPACE_1_INIT) &&
       (function_code <= FUNCTION_CODE_USER_DEFINED_SPACE_1_END)) ||
      ((function_code >= FUNCTION_CODE_USER_DEFINED_SPACE_2_INIT) &&
       (function_code <= FUNCTION_CODE_USER_DEFINED_SPACE_2_END))) {
    // Handle a user-defined function. We don't know how long the frame ought to be, so ideally the
    // whole length detection would be delegated to whatever handler is installed. Instead we rely on
    // the CRC: if it matches, there is a good chance this is a complete message. There is a small
    // chance it isn't, but that is quite small given the purpose of the CRC in the first place.

    // The early guard ensures at >= 3: at least address, function code and two candidate CRC bytes.
    data_len = at - 2;
    data_offset = 1;

    uint16_t computed_crc = crc16(raw, data_offset + data_len);
    uint16_t remote_crc = uint16_t(raw[data_offset + data_len]) | (uint16_t(raw[data_offset + data_len + 1]) << 8);

    if (computed_crc != remote_crc)
      return true;

    ESP_LOGD(TAG, "Modbus user-defined function %02X found", function_code);

  } else {
    // data starts at 2 and length is 4 for read registers commands
    if (this->role == ModbusRole::SERVER) {
      if (function_code == ModbusFunctionCode::READ_COILS ||
          function_code == ModbusFunctionCode::READ_DISCRETE_INPUTS ||
          function_code == ModbusFunctionCode::READ_HOLDING_REGISTERS ||
          function_code == ModbusFunctionCode::READ_INPUT_REGISTERS ||
          function_code == ModbusFunctionCode::WRITE_SINGLE_REGISTER) {
        data_offset = 2;
        data_len = 4;
      } else if (function_code == ModbusFunctionCode::WRITE_MULTIPLE_REGISTERS) {
        // The byte count lives in raw[6]; wait until it has been received.
        if (at < 6) {
          return true;
        }
        data_offset = 2;
        // starting address (2 bytes) + quantity of registers (2 bytes) + byte count itself (1 byte) + actual byte count
        data_len = 2 + 2 + 1 + raw[6];
      }
    } else {
      // the response for write command mirrors the requests and data starts at offset 2 instead of 3 for read commands
      if (function_code == ModbusFunctionCode::WRITE_SINGLE_COIL ||
          function_code == ModbusFunctionCode::WRITE_SINGLE_REGISTER ||
          function_code == ModbusFunctionCode::WRITE_MULTIPLE_COILS ||
          function_code == ModbusFunctionCode::WRITE_MULTIPLE_REGISTERS) {
        data_offset = 2;
        data_len = 4;
      }
    }

    // Error ( msb indicates error )
    // response format:  Byte[0] = device address, Byte[1] function code | 0x80 , Byte[2] exception code, Byte[3-4] crc
    if ((function_code & FUNCTION_CODE_EXCEPTION_MASK) == FUNCTION_CODE_EXCEPTION_MASK) {
      data_offset = 2;
      data_len = 1;
    }

    // Byte data_offset..data_offset+data_len-1: Data
    if (at < data_offset + data_len)
      return true;

    // Byte data_offset+data_len: CRC_LO (over all bytes)
    if (at == data_offset + data_len)
      return true;

    // Byte data_offset+data_len+1: CRC_HI (over all bytes)
    uint16_t computed_crc = crc16(raw, data_offset + data_len);
    uint16_t remote_crc = uint16_t(raw[data_offset + data_len]) | (uint16_t(raw[data_offset + data_len + 1]) << 8);
    if (computed_crc != remote_crc) {
      if (this->disable_crc_) {
        ESP_LOGD(TAG, "Modbus CRC Check failed, but ignored! %02X!=%02X", computed_crc, remote_crc);
      } else {
        ESP_LOGW(TAG, "Modbus CRC Check failed! %02X!=%02X", computed_crc, remote_crc);
        return false;
      }
    }
  }
  std::vector<uint8_t> data(this->rx_buffer_.begin() + data_offset, this->rx_buffer_.begin() + data_offset + data_len);
  bool found = false;
  for (auto *device : this->devices_) {
    if (device->address_ == address) {
      found = true;
      // Is it an error response?
      if ((function_code & FUNCTION_CODE_EXCEPTION_MASK) == FUNCTION_CODE_EXCEPTION_MASK) {
        ESP_LOGD(TAG, "Modbus error function code: 0x%X exception: %d", function_code, raw[2]);
        if (waiting_for_response != 0) {
          device->on_modbus_error(function_code & FUNCTION_CODE_MASK, raw[2]);
        } else {
          // Ignore modbus exception not related to a pending command
          ESP_LOGD(TAG, "Ignoring Modbus error - not expecting a response");
        }
        continue;
      }
      if (this->role == ModbusRole::SERVER) {
        if (function_code == ModbusFunctionCode::READ_HOLDING_REGISTERS ||
            function_code == ModbusFunctionCode::READ_INPUT_REGISTERS) {
          // data = start address (big-endian) followed by register count (big-endian).
          device->on_modbus_read_registers(function_code, uint16_t(data[1]) | (uint16_t(data[0]) << 8),
                                           uint16_t(data[3]) | (uint16_t(data[2]) << 8));
          continue;
        }
        if (function_code == ModbusFunctionCode::WRITE_SINGLE_REGISTER ||
            function_code == ModbusFunctionCode::WRITE_MULTIPLE_REGISTERS) {
          device->on_modbus_write_registers(function_code, data);
          continue;
        }
      }
      // fallthrough for other function codes
      device->on_modbus_data(data);
    }
  }
  waiting_for_response = 0;

  if (!found) {
    ESP_LOGW(TAG, "Got Modbus frame from unknown address 0x%02X! ", address);
  }

  // reset buffer
  ESP_LOGV(TAG, "Clearing buffer of %zu bytes - parse succeeded", at);
  this->rx_buffer_.clear();
  return true;
}

void Modbus::dump_config() {
  // One multi-line config entry for the bus settings, followed by the (optional) flow control pin.
  ESP_LOGCONFIG(TAG,
                "Modbus:\n"
                "  Send Wait Time: %d ms\n"
                "  CRC Disabled: %s",
                this->send_wait_time_, this->disable_crc_ ? "YES" : "NO");
  LOG_PIN("  Flow Control Pin: ", this->flow_control_pin_);
}
float Modbus::get_setup_priority() const {
  // Run setup just after the UART bus this component talks through.
  constexpr float AFTER_BUS_OFFSET = 1.0f;
  return setup_priority::BUS - AFTER_BUS_OFFSET;
}

/// Build a Modbus RTU frame, transmit it, and remember `address` as the device a response is awaited from.
///
/// Frame layout, in order:
/// 1. Address and function code.
/// 2. Client role only: the big-endian start address, then the big-endian entity count (omitted for single
///    coil/register writes).
/// 3. The optional payload. In server role or for write-multiple requests it is prefixed by its byte count;
///    otherwise exactly 2 payload bytes are sent.
/// 4. The CRC-16, low byte first.
///
/// @param address            Target device address.
/// @param function_code      Modbus function code.
/// @param start_address      First register/coil address (client role only).
/// @param number_of_entities Register/coil count (client role only; capped for standard codes).
/// @param payload_len        Payload length in bytes (forced to 2 for single writes).
/// @param payload            Payload bytes, or nullptr for none.
void Modbus::send(uint8_t address, uint8_t function_code, uint16_t start_address, uint16_t number_of_entities,
                  uint8_t payload_len, const uint8_t *payload) {
  static const size_t MAX_VALUES = 128;

  // The entity limit applies only to standard function codes; some devices use
  // non-standard codes such as 0x43, which are passed through unchecked.
  const bool is_standard_code = function_code <= ModbusFunctionCode::WRITE_MULTIPLE_REGISTERS;
  if (is_standard_code && number_of_entities > MAX_VALUES) {
    ESP_LOGE(TAG, "send too many values %d max=%zu", number_of_entities, MAX_VALUES);
    return;
  }

  static constexpr size_t ADDR_SIZE = 1;
  static constexpr size_t FC_SIZE = 1;
  static constexpr size_t START_ADDR_SIZE = 2;
  static constexpr size_t NUM_ENTITIES_SIZE = 2;
  static constexpr size_t BYTE_COUNT_SIZE = 1;
  static constexpr size_t MAX_PAYLOAD_SIZE = std::numeric_limits<uint8_t>::max();
  static constexpr size_t CRC_SIZE = 2;
  static constexpr size_t MAX_FRAME_SIZE =
      ADDR_SIZE + FC_SIZE + START_ADDR_SIZE + NUM_ENTITIES_SIZE + BYTE_COUNT_SIZE + MAX_PAYLOAD_SIZE + CRC_SIZE;

  uint8_t frame[MAX_FRAME_SIZE];
  size_t len = 0;
  auto append = [&frame, &len](uint8_t value) { frame[len++] = value; };

  append(address);
  append(function_code);

  if (this->role == ModbusRole::CLIENT) {
    append(static_cast<uint8_t>(start_address >> 8));
    append(static_cast<uint8_t>(start_address & 0xFF));
    const bool single_write = function_code == ModbusFunctionCode::WRITE_SINGLE_COIL ||
                              function_code == ModbusFunctionCode::WRITE_SINGLE_REGISTER;
    if (!single_write) {
      append(static_cast<uint8_t>(number_of_entities >> 8));
      append(static_cast<uint8_t>(number_of_entities & 0xFF));
    }
  }

  if (payload != nullptr) {
    const bool needs_byte_count = this->role == ModbusRole::SERVER ||
                                  function_code == ModbusFunctionCode::WRITE_MULTIPLE_COILS ||
                                  function_code == ModbusFunctionCode::WRITE_MULTIPLE_REGISTERS;
    if (needs_byte_count) {
      append(payload_len);  // byte count precedes the data for write-multiple / server responses
    } else {
      payload_len = 2;  // single coil/register write carries exactly one 16-bit value
    }
    for (const uint8_t *p = payload, *end = payload + payload_len; p != end; ++p) {
      append(*p);
    }
  }

  // CRC covers everything appended so far and is transmitted low byte first.
  const auto crc = crc16(frame, len);
  append(static_cast<uint8_t>(crc & 0xFF));
  append(static_cast<uint8_t>((crc >> 8) & 0xFF));

  if (this->flow_control_pin_ != nullptr)
    this->flow_control_pin_->digital_write(true);

  this->write_array(frame, len);
  this->flush();

  if (this->flow_control_pin_ != nullptr)
    this->flow_control_pin_->digital_write(false);

  waiting_for_response = address;
  last_send_ = millis();
#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE
  char hex_buf[format_hex_pretty_size(MODBUS_MAX_LOG_BYTES)];
#endif
  ESP_LOGV(TAG, "Modbus write: %s", format_hex_pretty_to(hex_buf, frame, len));
}

// Helper function for lambdas.
// Sends a raw command: the payload must already contain the complete frame except the CRC,
// which is computed here and appended (low byte first).
void Modbus::send_raw(const std::vector<uint8_t> &payload) {
  // An empty payload has no address byte, so there is nothing to send or wait for.
  if (payload.empty())
    return;

  // Pure computation; done up front so the transmit sequence below stays contiguous.
  const auto crc = crc16(payload.data(), payload.size());

  if (this->flow_control_pin_ != nullptr)
    this->flow_control_pin_->digital_write(true);

  this->write_array(payload);
  this->write_byte(static_cast<uint8_t>(crc & 0xFF));
  this->write_byte(static_cast<uint8_t>((crc >> 8) & 0xFF));
  this->flush();

  if (this->flow_control_pin_ != nullptr)
    this->flow_control_pin_->digital_write(false);

  // The first payload byte is the target address, mirroring what send() records.
  waiting_for_response = payload.front();
#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE
  char hex_buf[format_hex_pretty_size(MODBUS_MAX_LOG_BYTES)];
#endif
  ESP_LOGV(TAG, "Modbus write raw: %s", format_hex_pretty_to(hex_buf, payload.data(), payload.size()));
  last_send_ = millis();
}

}  // namespace modbus
}  // namespace esphome
