00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 #include <stdio.h>
00031 #include <stdlib.h>
00032 #include <unistd.h>
00033 #include <sys/types.h>
00034 #include <sys/time.h>
00035 #include <sys/socket.h>
00036 #include <netinet/in.h>
00037 #include <netdb.h>
00038 #include <errno.h>
00039
00040
00041 #include "../../mem/mem.h"
00042
00043
00044 #include "../../dprint.h"
00045 #include "../../str.h"
00046
00047
00048 #include "auth_diameter.h"
00049 #include "defs.h"
00050 #include "tcp_comm.h"
00051 #include "diameter_msg.h"
00052
/* Maximum number of replies with a non-matching message id that
 * tcp_send_recv() discards before giving up and returning AAA_TIMEOUT. */
#define MAX_TRIES 10
00054
00055
/*
 * Open a TCP connection to the DIAMETER peer at host:port.
 *
 * host is resolved with gethostbyname(). Only IPv4 results are accepted,
 * because the address is copied into a struct sockaddr_in.
 *
 * Returns the connected socket descriptor, or -1 on failure. The error is
 * logged and the socket, if one was created, is closed.
 */
int init_mytcp(char* host, int port)
{
	int sockfd;
	struct sockaddr_in serv_addr;
	struct hostent *server;

	sockfd = socket(PF_INET, SOCK_STREAM, 0);
	if (sockfd < 0) {
		LM_ERR("error creating the socket\n");
		return -1;
	}

	server = gethostbyname(host);
	if (server == NULL) {
		LM_ERR("error finding the host\n");
		goto error;
	}

	/* guard the fixed-size copy into sin_addr below: a non-IPv4 result
	 * would otherwise overflow the 4-byte s_addr field */
	if (server->h_addrtype != AF_INET
			|| server->h_length != (int)sizeof(serv_addr.sin_addr.s_addr)
			|| server->h_addr_list[0] == NULL) {
		LM_ERR("host does not resolve to an IPv4 address\n");
		goto error;
	}

	memset(&serv_addr, 0, sizeof(serv_addr));
	serv_addr.sin_family = AF_INET;
	memcpy(&serv_addr.sin_addr.s_addr, server->h_addr_list[0],
		sizeof(serv_addr.sin_addr.s_addr));
	serv_addr.sin_port = htons(port);

	if (connect(sockfd, (const struct sockaddr *)&serv_addr,
			sizeof(serv_addr)) < 0) {
		LM_ERR("error connecting to the "
			"DIAMETER client\n");
		goto error;
	}

	return sockfd;

error:
	/* do not leak the descriptor when setup fails */
	close(sockfd);
	return -1;
}
00093
00094
00095
00096 void reset_read_buffer(rd_buf_t *rb)
00097 {
00098 rb->ret_code = 0;
00099 rb->chall_len = 0;
00100 if(rb->chall)
00101 pkg_free(rb->chall);
00102 rb->chall = 0;
00103
00104 rb->first_4bytes = 0;
00105 rb->buf_len = 0;
00106 if(rb->buf)
00107 pkg_free(rb->buf);
00108 rb->buf = 0;
00109 }
00110
00111
00112 int do_read( int socket, rd_buf_t *p)
00113 {
00114 unsigned char *ptr;
00115 unsigned int wanted_len, len;
00116 int n;
00117
00118 if (p->buf==0)
00119 {
00120 wanted_len = sizeof(p->first_4bytes) - p->buf_len;
00121 ptr = ((unsigned char*)&(p->first_4bytes)) + p->buf_len;
00122 }
00123 else
00124 {
00125 wanted_len = p->first_4bytes - p->buf_len;
00126 ptr = p->buf + p->buf_len;
00127 }
00128
00129 while( (n=recv( socket, ptr, wanted_len, MSG_DONTWAIT ))>0 )
00130 {
00131
00132 p->buf_len += n;
00133 if (n<wanted_len)
00134 {
00135
00136 wanted_len -= n;
00137 ptr += n;
00138 }
00139 else
00140 {
00141 if (p->buf==0)
00142 {
00143
00144 len = ntohl(p->first_4bytes)&0x00ffffff;
00145 if (len<AAA_MSG_HDR_SIZE || len>MAX_AAA_MSG_SIZE)
00146 {
00147 LM_ERR(" (sock=%d): invalid message "
00148 "length read %u (%x)\n", socket, len, p->first_4bytes);
00149 goto error;
00150 }
00151
00152 if ( (p->buf=pkg_malloc(len))==0 )
00153 {
00154 LM_ERR("no more pkg memory\n");
00155 goto error;
00156 }
00157 *((unsigned int*)p->buf) = p->first_4bytes;
00158 p->buf_len = sizeof(p->first_4bytes);
00159 p->first_4bytes = len;
00160
00161 ptr = p->buf + p->buf_len;
00162 wanted_len = p->first_4bytes - p->buf_len;
00163 }
00164 else
00165 {
00166
00167 LM_DBG("(sock=%d): whole message read (len=%d)!\n",
00168 socket, p->first_4bytes);
00169 return CONN_SUCCESS;
00170 }
00171 }
00172 }
00173
00174 if (n==0)
00175 {
00176 LM_INFO("(sock=%d): FIN received\n", socket);
00177 return CONN_CLOSED;
00178 }
00179 if ( n==-1 && errno!=EINTR && errno!=EAGAIN )
00180 {
00181 LM_ERR(" (sock=%d): n=%d , errno=%d (%s)\n",
00182 socket, n, errno, strerror(errno));
00183 goto error;
00184 }
00185 error:
00186 return CONN_ERROR;
00187 }
00188
00189
00190
00191 int tcp_send_recv(int sockfd, char* buf, int len, rd_buf_t* rb,
00192 unsigned int waited_id)
00193 {
00194 int n, number_of_tries;
00195 fd_set active_fd_set, read_fd_set;
00196 struct timeval tv;
00197 unsigned long int result_code;
00198 AAAMessage *msg;
00199 AAA_AVP *avp;
00200 char serviceType;
00201 unsigned int m_id;
00202
00203
00204 while( (n=write(sockfd, buf, len))==-1 )
00205 {
00206 if (errno==EINTR)
00207 continue;
00208 LM_ERR("write returned error: %s\n", strerror(errno));
00209 return AAA_ERROR;
00210 }
00211
00212 if (n!=len)
00213 {
00214 LM_ERR("write gave no error but wrote less than asked\n");
00215 return AAA_ERROR;
00216 }
00217
00218
00219 tv.tv_sec = MAX_WAIT_SEC;
00220 tv.tv_usec = MAX_WAIT_USEC;
00221
00222
00223 FD_ZERO (&active_fd_set);
00224 FD_SET (sockfd, &active_fd_set);
00225 number_of_tries = 0;
00226
00227 while(number_of_tries<MAX_TRIES)
00228 {
00229 read_fd_set = active_fd_set;
00230 if (select (sockfd+1, &read_fd_set, NULL, NULL, &tv) < 0)
00231 {
00232 LM_ERR("select function failed\n");
00233 return AAA_ERROR;
00234 }
00235
00236
00237
00238
00239
00240
00241
00242
00243 reset_read_buffer(rb);
00244 switch( do_read(sockfd, rb) )
00245 {
00246 case CONN_ERROR:
00247 LM_ERR("error when trying to read from socket\n");
00248 return AAA_CONN_CLOSED;
00249 case CONN_CLOSED:
00250 LM_ERR("connection closed by diameter client!\n");
00251 return AAA_CONN_CLOSED;
00252 }
00253
00254
00255 msg = AAATranslateMessage(rb->buf, rb->buf_len, 0);
00256 if(!msg)
00257 {
00258 LM_ERR("message structure not obtained\n");
00259 return AAA_ERROR;
00260 }
00261 avp = AAAFindMatchingAVP(msg, NULL, AVP_SIP_MSGID,
00262 vendorID, AAA_FORWARD_SEARCH);
00263 if(!avp)
00264 {
00265 LM_ERR("AVP_SIP_MSGID not found\n");
00266 return AAA_ERROR;
00267 }
00268 m_id = *((unsigned int*)(avp->data.s));
00269 LM_DBG("######## m_id=%d\n", m_id);
00270 if(m_id!=waited_id)
00271 {
00272 number_of_tries ++;
00273 LM_NOTICE("old message received\n");
00274 continue;
00275 }
00276 goto next;
00277 }
00278
00279 LM_ERR("too many old messages received\n");
00280 return AAA_TIMEOUT;
00281 next:
00282
00283 avp = AAAFindMatchingAVP(msg, NULL, AVP_Service_Type,
00284 vendorID, AAA_FORWARD_SEARCH);
00285 if(!avp)
00286 {
00287 LM_ERR("AVP_Service_Type not found\n");
00288 return AAA_ERROR;
00289 }
00290 serviceType = avp->data.s[0];
00291
00292 result_code = ntohl(*((unsigned long int*)(msg->res_code->data.s)));
00293 switch(result_code)
00294 {
00295 case AAA_SUCCESS:
00296 rb->ret_code = AAA_AUTHORIZED;
00297 break;
00298 case AAA_AUTHENTICATION_REJECTED:
00299 if(serviceType!=SIP_AUTH_SERVICE)
00300 {
00301 rb->ret_code = AAA_NOT_AUTHORIZED;
00302 break;
00303 }
00304 avp = AAAFindMatchingAVP(msg, NULL, AVP_Challenge,
00305 vendorID, AAA_FORWARD_SEARCH);
00306 if(!avp)
00307 {
00308 LM_ERR("AVP_Response not found\n");
00309 rb->ret_code = AAA_SRVERR;
00310 break;
00311 }
00312 rb->chall_len=avp->data.len;
00313 rb->chall = (unsigned char*)pkg_malloc(avp->data.len*sizeof(char));
00314 if(rb->chall == NULL)
00315 {
00316 LM_ERR("no more pkg memory\n");
00317 rb->ret_code = AAA_SRVERR;
00318 break;
00319 }
00320 memcpy(rb->chall, avp->data.s, avp->data.len);
00321 rb->ret_code = AAA_CHALENGE;
00322 break;
00323 case AAA_AUTHORIZATION_REJECTED:
00324 rb->ret_code = AAA_NOT_AUTHORIZED;
00325 break;
00326 default:
00327 rb->ret_code = AAA_SRVERR;
00328 }
00329
00330 return rb->ret_code;
00331 }
/*
 * Tear down the connection to the DIAMETER peer and release the descriptor.
 *
 * shutdown(SHUT_RDWR) stops traffic in both directions, so the peer sees
 * EOF. shutdown() alone does not free the file descriptor; close() is
 * required for that, otherwise the descriptor leaks.
 */
void close_tcp_connection(int sfd)
{
	shutdown(sfd, SHUT_RDWR);
	close(sfd);
}
00336
00337